diff --git a/docs/guides/code_examples/storages/rq_basic_example.py b/docs/guides/code_examples/storages/rq_basic_example.py index 9e983bb9fe..388c184fc6 100644 --- a/docs/guides/code_examples/storages/rq_basic_example.py +++ b/docs/guides/code_examples/storages/rq_basic_example.py @@ -12,7 +12,7 @@ async def main() -> None: await request_queue.add_request('https://apify.com/') # Add multiple requests as a batch. - await request_queue.add_requests_batched( + await request_queue.add_requests( ['https://crawlee.dev/', 'https://crawlee.dev/python/'] ) diff --git a/docs/guides/code_examples/storages/rq_with_crawler_explicit_example.py b/docs/guides/code_examples/storages/rq_with_crawler_explicit_example.py index 21bedad0b9..4ef61efc82 100644 --- a/docs/guides/code_examples/storages/rq_with_crawler_explicit_example.py +++ b/docs/guides/code_examples/storages/rq_with_crawler_explicit_example.py @@ -10,7 +10,7 @@ async def main() -> None: request_queue = await RequestQueue.open(name='my-request-queue') # Interact with the request queue directly, e.g. add a batch of requests. - await request_queue.add_requests_batched( + await request_queue.add_requests( ['https://apify.com/', 'https://crawlee.dev/'] ) diff --git a/pyproject.toml b/pyproject.toml index 52fa636dbb..4abfd3d64c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -93,8 +93,8 @@ crawlee = "crawlee._cli:cli" [dependency-groups] dev = [ - "apify_client", # For e2e tests. - "build~=1.2.2", # For e2e tests. + "apify-client", # For e2e tests. + "build~=1.2.2", # For e2e tests. "mypy~=1.15.0", "pre-commit~=4.2.0", "proxy-py~=2.4.0", @@ -105,7 +105,7 @@ dev = [ "pytest-xdist~=3.6.0", "pytest~=8.3.0", "ruff~=0.11.0", - "setuptools~=79.0.0", # setuptools are used by pytest, but not explicitly required + "setuptools", # setuptools are used by pytest, but not explicitly required "sortedcontainers-stubs~=2.4.0", "types-beautifulsoup4~=4.12.0.20240229", "types-cachetools~=5.5.0.20240820", @@ -144,7 +144,6 @@ ignore = [ "PLR0911", # Too many return statements "PLR0913", # Too many arguments in function definition "PLR0915", # Too many statements - "PTH", # flake8-use-pathlib "PYI034", # `__aenter__` methods in classes like `{name}` usually return `self` at runtime "PYI036", # The second argument in `__aexit__` should be annotated with `object` or `BaseException | None` "S102", # Use of `exec` detected @@ -166,6 +165,7 @@ indent-style = "space" "F401", # Unused imports ] "**/{tests}/*" = [ + "ASYNC230", # Async functions should not open files with blocking methods like `open` "D", # Everything from the pydocstyle "INP001", # File {filename} is part of an implicit namespace package, add an __init__.py "PLR2004", # Magic value used in comparison, consider replacing {value} with a constant variable diff --git a/src/crawlee/_cli.py b/src/crawlee/_cli.py index 689cb7b182..d8b295b5ed 100644 --- a/src/crawlee/_cli.py +++ b/src/crawlee/_cli.py @@ -22,7 +22,7 @@ cli = typer.Typer(no_args_is_help=True) template_directory = importlib.resources.files('crawlee') / 'project_template' -with open(str(template_directory / 'cookiecutter.json')) as f: +with (template_directory / 'cookiecutter.json').open() as f: cookiecutter_json = json.load(f) crawler_choices = cookiecutter_json['crawler_type'] diff --git a/src/crawlee/_service_locator.py b/src/crawlee/_service_locator.py index 31bc36c63c..2cb8f8302a 100644 --- a/src/crawlee/_service_locator.py +++ b/src/crawlee/_service_locator.py @@ -3,8 +3,8 @@ from crawlee._utils.docs import docs_group from crawlee.configuration 
import Configuration from crawlee.errors import ServiceConflictError -from crawlee.events import EventManager -from crawlee.storage_clients import StorageClient +from crawlee.events import EventManager, LocalEventManager +from crawlee.storage_clients import FileSystemStorageClient, StorageClient @docs_group('Classes') @@ -49,8 +49,6 @@ def set_configuration(self, configuration: Configuration) -> None: def get_event_manager(self) -> EventManager: """Get the event manager.""" if self._event_manager is None: - from crawlee.events import LocalEventManager - self._event_manager = ( LocalEventManager().from_config(config=self._configuration) if self._configuration @@ -77,13 +75,7 @@ def set_event_manager(self, event_manager: EventManager) -> None: def get_storage_client(self) -> StorageClient: """Get the storage client.""" if self._storage_client is None: - from crawlee.storage_clients import MemoryStorageClient - - self._storage_client = ( - MemoryStorageClient.from_config(config=self._configuration) - if self._configuration - else MemoryStorageClient.from_config() - ) + self._storage_client = FileSystemStorageClient() self._storage_client_was_retrieved = True return self._storage_client diff --git a/src/crawlee/_types.py b/src/crawlee/_types.py index c68ae63df9..9b6cb0f2e7 100644 --- a/src/crawlee/_types.py +++ b/src/crawlee/_types.py @@ -275,10 +275,6 @@ async def push_data( **kwargs: Unpack[PushDataKwargs], ) -> None: """Track a call to the `push_data` context helper.""" - from crawlee.storages._dataset import Dataset - - await Dataset.check_and_serialize(data) - self.push_data_calls.append( PushDataFunctionCall( data=data, diff --git a/src/crawlee/_utils/file.py b/src/crawlee/_utils/file.py index 022d0604ef..4de6804490 100644 --- a/src/crawlee/_utils/file.py +++ b/src/crawlee/_utils/file.py @@ -2,18 +2,26 @@ import asyncio import contextlib -import io +import csv import json import mimetypes import os import re import shutil from enum import Enum +from logging import getLogger from typing import TYPE_CHECKING if TYPE_CHECKING: + from collections.abc import AsyncIterator from pathlib import Path - from typing import Any + from typing import Any, TextIO + + from typing_extensions import Unpack + + from crawlee.storages._types import ExportDataCsvKwargs, ExportDataJsonKwargs + +logger = getLogger(__name__) class ContentType(Enum): @@ -83,28 +91,67 @@ def determine_file_extension(content_type: str) -> str | None: return ext[1:] if ext is not None else ext -def is_file_or_bytes(value: Any) -> bool: - """Determine if the input value is a file-like object or bytes. - - This function checks whether the provided value is an instance of bytes, bytearray, or io.IOBase (file-like). - The method is simplified for common use cases and may not cover all edge cases. +async def json_dumps(obj: Any) -> str: + """Serialize an object to a JSON-formatted string with specific settings. Args: - value: The value to be checked. + obj: The object to serialize. Returns: - True if the value is either a file-like object or bytes, False otherwise. + A string containing the JSON representation of the input object. """ - return isinstance(value, (bytes, bytearray, io.IOBase)) + return await asyncio.to_thread(json.dumps, obj, ensure_ascii=False, indent=2, default=str) -async def json_dumps(obj: Any) -> str: - """Serialize an object to a JSON-formatted string with specific settings. +def infer_mime_type(value: Any) -> str: + """Infer the MIME content type from the value. Args: - obj: The object to serialize. 
+ value: The value to infer the content type from. Returns: - A string containing the JSON representation of the input object. + The inferred MIME content type. """ - return await asyncio.to_thread(json.dumps, obj, ensure_ascii=False, indent=2, default=str) + # If the value is bytes (or bytearray), return binary content type. + if isinstance(value, (bytes, bytearray)): + return 'application/octet-stream' + + # If the value is a dict or list, assume JSON. + if isinstance(value, (dict, list)): + return 'application/json; charset=utf-8' + + # If the value is a string, assume plain text. + if isinstance(value, str): + return 'text/plain; charset=utf-8' + + # Default fallback. + return 'application/octet-stream' + + +async def export_json_to_stream( + iterator: AsyncIterator[dict], + dst: TextIO, + **kwargs: Unpack[ExportDataJsonKwargs], +) -> None: + items = [item async for item in iterator] + json.dump(items, dst, **kwargs) + + +async def export_csv_to_stream( + iterator: AsyncIterator[dict], + dst: TextIO, + **kwargs: Unpack[ExportDataCsvKwargs], +) -> None: + writer = csv.writer(dst, **kwargs) + write_header = True + + # Iterate over the dataset and write to CSV. + async for item in iterator: + if not item: + continue + + if write_header: + writer.writerow(item.keys()) + write_header = False + + writer.writerow(item.values()) diff --git a/src/crawlee/configuration.py b/src/crawlee/configuration.py index de22118816..e3ef39f486 100644 --- a/src/crawlee/configuration.py +++ b/src/crawlee/configuration.py @@ -118,21 +118,7 @@ class Configuration(BaseSettings): ) ), ] = True - """Whether to purge the storage on the start. This option is utilized by the `MemoryStorageClient`.""" - - write_metadata: Annotated[bool, Field(alias='crawlee_write_metadata')] = True - """Whether to write the storage metadata. This option is utilized by the `MemoryStorageClient`.""" - - persist_storage: Annotated[ - bool, - Field( - validation_alias=AliasChoices( - 'apify_persist_storage', - 'crawlee_persist_storage', - ) - ), - ] = True - """Whether to persist the storage. This option is utilized by the `MemoryStorageClient`.""" + """Whether to purge the storage on the start. This option is utilized by the storage clients.""" persist_state_interval: Annotated[ timedelta_ms, @@ -239,7 +225,7 @@ class Configuration(BaseSettings): ), ), ] = './storage' - """The path to the storage directory. This option is utilized by the `MemoryStorageClient`.""" + """The path to the storage directory. 
This option is utilized by the storage clients.""" headless: Annotated[ bool, diff --git a/src/crawlee/crawlers/_basic/_basic_crawler.py b/src/crawlee/crawlers/_basic/_basic_crawler.py index 3573033149..37497d2bd2 100644 --- a/src/crawlee/crawlers/_basic/_basic_crawler.py +++ b/src/crawlee/crawlers/_basic/_basic_crawler.py @@ -34,6 +34,7 @@ SendRequestFunction, ) from crawlee._utils.docs import docs_group +from crawlee._utils.file import export_csv_to_stream, export_json_to_stream from crawlee._utils.robots import RobotsTxtFile from crawlee._utils.urls import convert_to_absolute_url, is_url_absolute from crawlee._utils.wait import wait_for @@ -60,7 +61,7 @@ import re from contextlib import AbstractAsyncContextManager - from crawlee._types import ConcurrencySettings, HttpMethod, JsonSerializable + from crawlee._types import ConcurrencySettings, HttpMethod, JsonSerializable, PushDataKwargs from crawlee.configuration import Configuration from crawlee.events import EventManager from crawlee.http_clients import HttpClient, HttpResponse @@ -70,7 +71,7 @@ from crawlee.statistics import FinalStatistics from crawlee.storage_clients import StorageClient from crawlee.storage_clients.models import DatasetItemsListPage - from crawlee.storages._dataset import ExportDataCsvKwargs, ExportDataJsonKwargs, GetDataKwargs, PushDataKwargs + from crawlee.storages._types import GetDataKwargs TCrawlingContext = TypeVar('TCrawlingContext', bound=BasicCrawlingContext, default=BasicCrawlingContext) TStatisticsState = TypeVar('TStatisticsState', bound=StatisticsState, default=StatisticsState) @@ -676,7 +677,7 @@ async def add_requests( request_manager = await self.get_request_manager() - await request_manager.add_requests_batched( + await request_manager.add_requests( requests=allowed_requests, batch_size=batch_size, wait_time_between_batches=wait_time_between_batches, @@ -684,13 +685,16 @@ async def add_requests( wait_for_all_requests_to_be_added_timeout=wait_for_all_requests_to_be_added_timeout, ) - async def _use_state(self, default_value: dict[str, JsonSerializable] | None = None) -> dict[str, JsonSerializable]: - store = await self.get_key_value_store() - return await store.get_auto_saved_value(self._CRAWLEE_STATE_KEY, default_value) + async def _use_state( + self, + default_value: dict[str, JsonSerializable] | None = None, + ) -> dict[str, JsonSerializable]: + # TODO: implement + return {} async def _save_crawler_state(self) -> None: - store = await self.get_key_value_store() - await store.persist_autosaved_values() + pass + # TODO: implement async def get_data( self, @@ -720,78 +724,29 @@ async def export_data( dataset_id: str | None = None, dataset_name: str | None = None, ) -> None: - """Export data from a `Dataset`. + """Export all items from a Dataset to a JSON or CSV file. - This helper method simplifies the process of exporting data from a `Dataset`. It opens the specified - one and then exports the data based on the provided parameters. If you need to pass options - specific to the output format, use the `export_data_csv` or `export_data_json` method instead. + This method simplifies the process of exporting data collected during crawling. It automatically + determines the export format based on the file extension (`.json` or `.csv`) and handles + the conversion of `Dataset` items to the appropriate format. Args: - path: The destination path. - dataset_id: The ID of the `Dataset`. - dataset_name: The name of the `Dataset`. + path: The destination file path. Must end with '.json' or '.csv'. 
+ dataset_id: The ID of the Dataset to export from. If None, uses `name` parameter instead. + dataset_name: The name of the Dataset to export from. If None, uses `id` parameter instead. """ dataset = await self.get_dataset(id=dataset_id, name=dataset_name) path = path if isinstance(path, Path) else Path(path) - destination = path.open('w', newline='') + dst = path.open('w', newline='') if path.suffix == '.csv': - await dataset.write_to_csv(destination) + await export_csv_to_stream(dataset.iterate_items(), dst) elif path.suffix == '.json': - await dataset.write_to_json(destination) + await export_json_to_stream(dataset.iterate_items(), dst) else: raise ValueError(f'Unsupported file extension: {path.suffix}') - async def export_data_csv( - self, - path: str | Path, - *, - dataset_id: str | None = None, - dataset_name: str | None = None, - **kwargs: Unpack[ExportDataCsvKwargs], - ) -> None: - """Export data from a `Dataset` to a CSV file. - - This helper method simplifies the process of exporting data from a `Dataset` in csv format. It opens - the specified one and then exports the data based on the provided parameters. - - Args: - path: The destination path. - content_type: The output format. - dataset_id: The ID of the `Dataset`. - dataset_name: The name of the `Dataset`. - kwargs: Extra configurations for dumping/writing in csv format. - """ - dataset = await self.get_dataset(id=dataset_id, name=dataset_name) - path = path if isinstance(path, Path) else Path(path) - - return await dataset.write_to_csv(path.open('w', newline=''), **kwargs) - - async def export_data_json( - self, - path: str | Path, - *, - dataset_id: str | None = None, - dataset_name: str | None = None, - **kwargs: Unpack[ExportDataJsonKwargs], - ) -> None: - """Export data from a `Dataset` to a JSON file. - - This helper method simplifies the process of exporting data from a `Dataset` in json format. It opens the - specified one and then exports the data based on the provided parameters. - - Args: - path: The destination path - dataset_id: The ID of the `Dataset`. - dataset_name: The name of the `Dataset`. - kwargs: Extra configurations for dumping/writing in json format. 
- """ - dataset = await self.get_dataset(id=dataset_id, name=dataset_name) - path = path if isinstance(path, Path) else Path(path) - - return await dataset.write_to_json(path.open('w', newline=''), **kwargs) - async def _push_data( self, data: JsonSerializable, @@ -1050,7 +1005,7 @@ async def _commit_request_handler_result(self, context: BasicCrawlingContext) -> ): requests.append(dst_request) - await request_manager.add_requests_batched(requests) + await request_manager.add_requests(requests) for push_data_call in result.push_data_calls: await self._push_data(**push_data_call) diff --git a/src/crawlee/fingerprint_suite/_browserforge_adapter.py b/src/crawlee/fingerprint_suite/_browserforge_adapter.py index d64ddd59f0..11f9f82d79 100644 --- a/src/crawlee/fingerprint_suite/_browserforge_adapter.py +++ b/src/crawlee/fingerprint_suite/_browserforge_adapter.py @@ -1,10 +1,10 @@ from __future__ import annotations -import os.path from collections.abc import Iterable from copy import deepcopy from functools import reduce from operator import or_ +from pathlib import Path from typing import TYPE_CHECKING, Any, Literal from browserforge.bayesian_network import extract_json @@ -253,9 +253,9 @@ def generate(self, browser_type: SupportedBrowserType = 'chromium') -> dict[str, def get_available_header_network() -> dict: """Get header network that contains possible header values.""" - if os.path.isfile(DATA_DIR / 'header-network.zip'): + if Path(DATA_DIR / 'header-network.zip').is_file(): return extract_json(DATA_DIR / 'header-network.zip') - if os.path.isfile(DATA_DIR / 'header-network-definition.zip'): + if Path(DATA_DIR / 'header-network-definition.zip').is_file(): return extract_json(DATA_DIR / 'header-network-definition.zip') raise FileNotFoundError('Missing header-network file.') diff --git a/src/crawlee/project_template/hooks/post_gen_project.py b/src/crawlee/project_template/hooks/post_gen_project.py index e076ff9308..c0495a724d 100644 --- a/src/crawlee/project_template/hooks/post_gen_project.py +++ b/src/crawlee/project_template/hooks/post_gen_project.py @@ -2,7 +2,6 @@ import subprocess from pathlib import Path - # % if cookiecutter.package_manager in ['poetry', 'uv'] Path('requirements.txt').unlink() @@ -32,8 +31,9 @@ # Install requirements and generate requirements.txt as an impromptu lockfile subprocess.check_call([str(path / 'pip'), 'install', '-r', 'requirements.txt']) -with open('requirements.txt', 'w') as requirements_txt: - subprocess.check_call([str(path / 'pip'), 'freeze'], stdout=requirements_txt) +Path('requirements.txt').write_text( + subprocess.check_output([str(path / 'pip'), 'freeze']).decode() +) # % if cookiecutter.crawler_type == 'playwright' subprocess.check_call([str(path / 'playwright'), 'install']) diff --git a/src/crawlee/request_loaders/_request_loader.py b/src/crawlee/request_loaders/_request_loader.py index e358306a45..2e3c8a3b73 100644 --- a/src/crawlee/request_loaders/_request_loader.py +++ b/src/crawlee/request_loaders/_request_loader.py @@ -25,10 +25,6 @@ class RequestLoader(ABC): - Managing state information such as the total and handled request counts. """ - @abstractmethod - async def get_total_count(self) -> int: - """Return an offline approximation of the total number of requests in the source (i.e. 
pending + handled).""" - @abstractmethod async def is_empty(self) -> bool: """Return True if there are no more requests in the source (there might still be unfinished requests).""" @@ -45,10 +41,6 @@ async def fetch_next_request(self) -> Request | None: async def mark_request_as_handled(self, request: Request) -> ProcessedRequest | None: """Mark a request as handled after a successful processing (or after giving up retrying).""" - @abstractmethod - async def get_handled_count(self) -> int: - """Return the number of handled requests.""" - async def to_tandem(self, request_manager: RequestManager | None = None) -> RequestManagerTandem: """Combine the loader with a request manager to support adding and reclaiming requests. diff --git a/src/crawlee/request_loaders/_request_manager.py b/src/crawlee/request_loaders/_request_manager.py index f63f962cb9..5a8427c2cb 100644 --- a/src/crawlee/request_loaders/_request_manager.py +++ b/src/crawlee/request_loaders/_request_manager.py @@ -6,12 +6,12 @@ from crawlee._utils.docs import docs_group from crawlee.request_loaders._request_loader import RequestLoader +from crawlee.storage_clients.models import ProcessedRequest if TYPE_CHECKING: from collections.abc import Sequence from crawlee._request import Request - from crawlee.storage_clients.models import ProcessedRequest @docs_group('Abstract classes') @@ -40,10 +40,11 @@ async def add_request( Information about the request addition to the manager. """ - async def add_requests_batched( + async def add_requests( self, requests: Sequence[str | Request], *, + forefront: bool = False, batch_size: int = 1000, # noqa: ARG002 wait_time_between_batches: timedelta = timedelta(seconds=1), # noqa: ARG002 wait_for_all_requests_to_be_added: bool = False, # noqa: ARG002 @@ -53,14 +54,17 @@ async def add_requests_batched( Args: requests: Requests to enqueue. + forefront: If True, add requests to the beginning of the queue. batch_size: The number of requests to add in one batch. wait_time_between_batches: Time to wait between adding batches. wait_for_all_requests_to_be_added: If True, wait for all requests to be added before returning. wait_for_all_requests_to_be_added_timeout: Timeout for waiting for all requests to be added. """ # Default and dumb implementation. 
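+        # This fallback adds the requests one by one via `add_request`, so the batching
+        # parameters documented above (`batch_size`, `wait_time_between_batches`, ...) are
+        # effectively ignored here; concrete request managers typically override this with
+        # a real batched implementation. A call such as
+        # `await manager.add_requests(['https://crawlee.dev/'], batch_size=500)` therefore
+        # still works, it just enqueues the URLs sequentially.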
+ processed_requests = list[ProcessedRequest]() for request in requests: - await self.add_request(request) + processed_request = await self.add_request(request, forefront=forefront) + processed_requests.append(processed_request) @abstractmethod async def reclaim_request(self, request: Request, *, forefront: bool = False) -> ProcessedRequest | None: diff --git a/src/crawlee/request_loaders/_request_manager_tandem.py b/src/crawlee/request_loaders/_request_manager_tandem.py index 9f0b8cefe8..5debdb7135 100644 --- a/src/crawlee/request_loaders/_request_manager_tandem.py +++ b/src/crawlee/request_loaders/_request_manager_tandem.py @@ -49,7 +49,7 @@ async def add_request(self, request: str | Request, *, forefront: bool = False) return await self._read_write_manager.add_request(request, forefront=forefront) @override - async def add_requests_batched( + async def add_requests( self, requests: Sequence[str | Request], *, @@ -58,7 +58,7 @@ async def add_requests_batched( wait_for_all_requests_to_be_added: bool = False, wait_for_all_requests_to_be_added_timeout: timedelta | None = None, ) -> None: - return await self._read_write_manager.add_requests_batched( + return await self._read_write_manager.add_requests( requests, batch_size=batch_size, wait_time_between_batches=wait_time_between_batches, diff --git a/src/crawlee/storage_clients/__init__.py b/src/crawlee/storage_clients/__init__.py index 66d352d7a7..ce8c713ca9 100644 --- a/src/crawlee/storage_clients/__init__.py +++ b/src/crawlee/storage_clients/__init__.py @@ -1,4 +1,9 @@ from ._base import StorageClient +from ._file_system import FileSystemStorageClient from ._memory import MemoryStorageClient -__all__ = ['MemoryStorageClient', 'StorageClient'] +__all__ = [ + 'FileSystemStorageClient', + 'MemoryStorageClient', + 'StorageClient', +] diff --git a/src/crawlee/storage_clients/_apify/__init__.py b/src/crawlee/storage_clients/_apify/__init__.py new file mode 100644 index 0000000000..4af7c8ee23 --- /dev/null +++ b/src/crawlee/storage_clients/_apify/__init__.py @@ -0,0 +1,11 @@ +from ._dataset_client import ApifyDatasetClient +from ._key_value_store_client import ApifyKeyValueStoreClient +from ._request_queue_client import ApifyRequestQueueClient +from ._storage_client import ApifyStorageClient + +__all__ = [ + 'ApifyDatasetClient', + 'ApifyKeyValueStoreClient', + 'ApifyRequestQueueClient', + 'ApifyStorageClient', +] diff --git a/src/crawlee/storage_clients/_apify/_dataset_client.py b/src/crawlee/storage_clients/_apify/_dataset_client.py new file mode 100644 index 0000000000..10cb47f028 --- /dev/null +++ b/src/crawlee/storage_clients/_apify/_dataset_client.py @@ -0,0 +1,198 @@ +from __future__ import annotations + +import asyncio +from logging import getLogger +from typing import TYPE_CHECKING, Any, ClassVar + +from apify_client import ApifyClientAsync +from typing_extensions import override + +from crawlee.storage_clients._base import DatasetClient +from crawlee.storage_clients.models import DatasetItemsListPage, DatasetMetadata + +if TYPE_CHECKING: + from collections.abc import AsyncIterator + from datetime import datetime + + from apify_client.clients import DatasetClientAsync + + from crawlee.configuration import Configuration + +logger = getLogger(__name__) + + +class ApifyDatasetClient(DatasetClient): + """An Apify platform implementation of the dataset client.""" + + _cache_by_name: ClassVar[dict[str, ApifyDatasetClient]] = {} + """A dictionary to cache clients by their names.""" + + def __init__( + self, + *, + id: str, + name: str, + 
created_at: datetime, + accessed_at: datetime, + modified_at: datetime, + item_count: int, + api_client: DatasetClientAsync, + ) -> None: + """Initialize a new instance. + + Preferably use the `ApifyDatasetClient.open` class method to create a new instance. + """ + self._metadata = DatasetMetadata( + id=id, + name=name, + created_at=created_at, + accessed_at=accessed_at, + modified_at=modified_at, + item_count=item_count, + ) + + self._api_client = api_client + """The Apify dataset client for API operations.""" + + self._lock = asyncio.Lock() + """A lock to ensure that only one operation is performed at a time.""" + + @override + @property + def metadata(self) -> DatasetMetadata: + return self._metadata + + @override + @classmethod + async def open( + cls, + *, + id: str | None, + name: str | None, + configuration: Configuration, + ) -> ApifyDatasetClient: + default_name = configuration.default_dataset_id + token = 'configuration.apify_token' # TODO: use the real value + api_url = 'configuration.apify_api_url' # TODO: use the real value + + name = name or default_name + + # Check if the client is already cached by name. + if name in cls._cache_by_name: + client = cls._cache_by_name[name] + await client._update_metadata() # noqa: SLF001 + return client + + # Otherwise, create a new one. + apify_client_async = ApifyClientAsync( + token=token, + api_url=api_url, + max_retries=8, + min_delay_between_retries_millis=500, + timeout_secs=360, + ) + + apify_datasets_client = apify_client_async.datasets() + + metadata = DatasetMetadata.model_validate( + await apify_datasets_client.get_or_create(name=id if id is not None else name), + ) + + apify_dataset_client = apify_client_async.dataset(dataset_id=metadata.id) + + client = cls( + id=metadata.id, + name=metadata.name, + created_at=metadata.created_at, + accessed_at=metadata.accessed_at, + modified_at=metadata.modified_at, + item_count=metadata.item_count, + api_client=apify_dataset_client, + ) + + # Cache the client by name. + cls._cache_by_name[name] = client + + return client + + @override + async def drop(self) -> None: + async with self._lock: + await self._api_client.delete() + + # Remove the client from the cache. 
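+            # Evicting the cached client here means that a later `open()` call with the
+            # same name creates a fresh client instead of reusing this deleted dataset.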
+ if self.metadata.name in self.__class__._cache_by_name: # noqa: SLF001 + del self.__class__._cache_by_name[self.metadata.name] # noqa: SLF001 + + @override + async def push_data(self, data: list[Any] | dict[str, Any]) -> None: + async with self._lock: + await self._api_client.push_items(items=data) + await self._update_metadata() + + @override + async def get_data( + self, + *, + offset: int = 0, + limit: int | None = 999_999_999_999, + clean: bool = False, + desc: bool = False, + fields: list[str] | None = None, + omit: list[str] | None = None, + unwind: str | None = None, + skip_empty: bool = False, + skip_hidden: bool = False, + flatten: list[str] | None = None, + view: str | None = None, + ) -> DatasetItemsListPage: + response = await self._api_client.list_items( + offset=offset, + limit=limit, + clean=clean, + desc=desc, + fields=fields, + omit=omit, + unwind=unwind, + skip_empty=skip_empty, + skip_hidden=skip_hidden, + flatten=flatten, + view=view, + ) + result = DatasetItemsListPage.model_validate(vars(response)) + await self._update_metadata() + return result + + @override + async def iterate_items( + self, + *, + offset: int = 0, + limit: int | None = None, + clean: bool = False, + desc: bool = False, + fields: list[str] | None = None, + omit: list[str] | None = None, + unwind: str | None = None, + skip_empty: bool = False, + skip_hidden: bool = False, + ) -> AsyncIterator[dict]: + async for item in self._api_client.iterate_items( + offset=offset, + limit=limit, + clean=clean, + desc=desc, + fields=fields, + omit=omit, + unwind=unwind, + skip_empty=skip_empty, + skip_hidden=skip_hidden, + ): + yield item + + await self._update_metadata() + + async def _update_metadata(self) -> None: + """Update the dataset metadata file with current information.""" + metadata = await self._api_client.get() + self._metadata = DatasetMetadata.model_validate(metadata) diff --git a/src/crawlee/storage_clients/_apify/_key_value_store_client.py b/src/crawlee/storage_clients/_apify/_key_value_store_client.py new file mode 100644 index 0000000000..621a9d9fe2 --- /dev/null +++ b/src/crawlee/storage_clients/_apify/_key_value_store_client.py @@ -0,0 +1,210 @@ +from __future__ import annotations + +import asyncio +from logging import getLogger +from typing import TYPE_CHECKING, Any, ClassVar + +from apify_client import ApifyClientAsync +from typing_extensions import override +from yarl import URL + +from crawlee.storage_clients._base import KeyValueStoreClient +from crawlee.storage_clients.models import ( + KeyValueStoreListKeysPage, + KeyValueStoreMetadata, + KeyValueStoreRecord, + KeyValueStoreRecordMetadata, +) + +if TYPE_CHECKING: + from collections.abc import AsyncIterator + from datetime import datetime + + from apify_client.clients import KeyValueStoreClientAsync + + from crawlee.configuration import Configuration + +logger = getLogger(__name__) + + +class ApifyKeyValueStoreClient(KeyValueStoreClient): + """An Apify platform implementation of the key-value store client.""" + + _cache_by_name: ClassVar[dict[str, ApifyKeyValueStoreClient]] = {} + """A dictionary to cache clients by their names.""" + + def __init__( + self, + *, + id: str, + name: str, + created_at: datetime, + accessed_at: datetime, + modified_at: datetime, + api_client: KeyValueStoreClientAsync, + ) -> None: + """Initialize a new instance. + + Preferably use the `ApifyKeyValueStoreClient.open` class method to create a new instance. 
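+
+        Args:
+            id: The ID of the key-value store.
+            name: The name of the key-value store.
+            created_at: The time at which the key-value store was created.
+            accessed_at: The time at which the key-value store was last accessed.
+            modified_at: The time at which the key-value store was last modified.
+            api_client: The underlying Apify key-value store API client.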
+ """ + self._metadata = KeyValueStoreMetadata( + id=id, + name=name, + created_at=created_at, + accessed_at=accessed_at, + modified_at=modified_at, + ) + + self._api_client = api_client + """The Apify key-value store client for API operations.""" + + self._lock = asyncio.Lock() + """A lock to ensure that only one operation is performed at a time.""" + + @override + @property + def metadata(self) -> KeyValueStoreMetadata: + return self._metadata + + @override + @classmethod + async def open( + cls, + *, + id: str | None, + name: str | None, + configuration: Configuration, + ) -> ApifyKeyValueStoreClient: + default_name = configuration.default_key_value_store_id + token = 'configuration.apify_token' # TODO: use the real value + api_url = 'configuration.apify_api_url' # TODO: use the real value + + name = name or default_name + + # Check if the client is already cached by name. + if name in cls._cache_by_name: + client = cls._cache_by_name[name] + await client._update_metadata() # noqa: SLF001 + return client + + # Otherwise, create a new one. + apify_client_async = ApifyClientAsync( + token=token, + api_url=api_url, + max_retries=8, + min_delay_between_retries_millis=500, + timeout_secs=360, + ) + + apify_kvss_client = apify_client_async.key_value_stores() + + metadata = KeyValueStoreMetadata.model_validate( + await apify_kvss_client.get_or_create(name=id if id is not None else name), + ) + + apify_kvs_client = apify_client_async.key_value_store(key_value_store_id=metadata.id) + + client = cls( + id=metadata.id, + name=metadata.name, + created_at=metadata.created_at, + accessed_at=metadata.accessed_at, + modified_at=metadata.modified_at, + api_client=apify_kvs_client, + ) + + # Cache the client by name. + cls._cache_by_name[name] = client + + return client + + @override + async def drop(self) -> None: + async with self._lock: + await self._api_client.delete() + + # Remove the client from the cache. 
+ if self.metadata.name in self.__class__._cache_by_name: # noqa: SLF001 + del self.__class__._cache_by_name[self.metadata.name] # noqa: SLF001 + + @override + async def get_value(self, key: str) -> KeyValueStoreRecord | None: + response = await self._api_client.get_record(key) + record = KeyValueStoreRecord.model_validate(response) if response else None + await self._update_metadata() + return record + + @override + async def set_value(self, key: str, value: Any, content_type: str | None = None) -> None: + async with self._lock: + await self._api_client.set_record( + key=key, + value=value, + content_type=content_type, + ) + await self._update_metadata() + + @override + async def delete_value(self, key: str) -> None: + async with self._lock: + await self._api_client.delete_record(key=key) + await self._update_metadata() + + @override + async def iterate_keys( + self, + *, + exclusive_start_key: str | None = None, + limit: int | None = None, + ) -> AsyncIterator[KeyValueStoreRecordMetadata]: + count = 0 + + while True: + response = await self._api_client.list_keys(exclusive_start_key=exclusive_start_key) + list_key_page = KeyValueStoreListKeysPage.model_validate(response) + + for item in list_key_page.items: + yield item + count += 1 + + # If we've reached the limit, stop yielding + if limit and count >= limit: + break + + # If we've reached the limit or there are no more pages, exit the loop + if (limit and count >= limit) or not list_key_page.is_truncated: + break + + exclusive_start_key = list_key_page.next_exclusive_start_key + + await self._update_metadata() + + async def get_public_url(self, key: str) -> str: + """Get a URL for the given key that may be used to publicly access the value in the remote key-value store. + + Args: + key: The key for which the URL should be generated. 
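+
+        Returns:
+            A public URL pointing to the record with the given key in the remote key-value store.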
+ """ + if self._api_client.resource_id is None: + raise ValueError('resource_id cannot be None when generating a public URL') + + public_url = ( + URL(self._api_client.base_url) / 'v2' / 'key-value-stores' / self._api_client.resource_id / 'records' / key + ) + + key_value_store = self.metadata + + if key_value_store and isinstance(getattr(key_value_store, 'model_extra', None), dict): + url_signing_secret_key = key_value_store.model_extra.get('urlSigningSecretKey') + if url_signing_secret_key: + # Note: This would require importing create_hmac_signature from apify._crypto + # public_url = public_url.with_query(signature=create_hmac_signature(url_signing_secret_key, key)) + # For now, I'll leave this part commented as we may need to add the proper import + pass + + return str(public_url) + + async def _update_metadata(self) -> None: + """Update the key-value store metadata with current information.""" + metadata = await self._api_client.get() + self._metadata = KeyValueStoreMetadata.model_validate(metadata) diff --git a/src/crawlee/storage_clients/_apify/_request_queue_client.py b/src/crawlee/storage_clients/_apify/_request_queue_client.py new file mode 100644 index 0000000000..d0f86041d2 --- /dev/null +++ b/src/crawlee/storage_clients/_apify/_request_queue_client.py @@ -0,0 +1,650 @@ +from __future__ import annotations + +import asyncio +import os +from collections import deque +from datetime import datetime, timedelta, timezone +from logging import getLogger +from typing import TYPE_CHECKING, ClassVar, Final + +from apify_client import ApifyClientAsync +from cachetools import LRUCache +from typing_extensions import override + +from crawlee import Request +from crawlee._utils.requests import unique_key_to_request_id +from crawlee.storage_clients._base import RequestQueueClient +from crawlee.storage_clients.models import ( + AddRequestsResponse, + CachedRequest, + ProcessedRequest, + ProlongRequestLockResponse, + RequestQueueHead, + RequestQueueMetadata, +) + +if TYPE_CHECKING: + from collections.abc import Sequence + + from apify_client.clients import RequestQueueClientAsync + + from crawlee.configuration import Configuration + +logger = getLogger(__name__) + + +class ApifyRequestQueueClient(RequestQueueClient): + """An Apify platform implementation of the request queue client.""" + + _cache_by_name: ClassVar[dict[str, ApifyRequestQueueClient]] = {} + """A dictionary to cache clients by their names.""" + + _DEFAULT_LOCK_TIME: Final[timedelta] = timedelta(minutes=3) + """The default lock time for requests in the queue.""" + + _MAX_CACHED_REQUESTS: Final[int] = 1_000_000 + """Maximum number of requests that can be cached.""" + + def __init__( + self, + *, + id: str, + name: str, + created_at: datetime, + accessed_at: datetime, + modified_at: datetime, + had_multiple_clients: bool, + handled_request_count: int, + pending_request_count: int, + stats: dict, + total_request_count: int, + api_client: RequestQueueClientAsync, + ) -> None: + """Initialize a new instance. + + Preferably use the `ApifyRequestQueueClient.open` class method to create a new instance. 
+ """ + self._metadata = RequestQueueMetadata( + id=id, + name=name, + created_at=created_at, + accessed_at=accessed_at, + modified_at=modified_at, + had_multiple_clients=had_multiple_clients, + handled_request_count=handled_request_count, + pending_request_count=pending_request_count, + stats=stats, + total_request_count=total_request_count, + ) + + self._api_client = api_client + """The Apify request queue client for API operations.""" + + self._lock = asyncio.Lock() + """A lock to ensure that only one operation is performed at a time.""" + + self._queue_head = deque[str]() + """A deque to store request IDs in the queue head.""" + + self._requests_cache: LRUCache[str, CachedRequest] = LRUCache(maxsize=self._MAX_CACHED_REQUESTS) + """A cache to store request objects.""" + + self._queue_has_locked_requests: bool | None = None + """Whether the queue has requests locked by another client.""" + + self._should_check_for_forefront_requests = False + """Whether to check for forefront requests in the next list_head call.""" + + @override + @property + def metadata(self) -> RequestQueueMetadata: + return self._metadata + + @override + @classmethod + async def open( + cls, + *, + id: str | None, + name: str | None, + configuration: Configuration, + ) -> ApifyRequestQueueClient: + default_name = configuration.default_request_queue_id + + # Get API credentials + token = os.environ.get('APIFY_TOKEN') + api_url = 'https://api.apify.com' + + name = name or default_name + + # Check if the client is already cached by name. + if name in cls._cache_by_name: + client = cls._cache_by_name[name] + await client._update_metadata() # noqa: SLF001 + return client + + # Create a new API client + apify_client_async = ApifyClientAsync( + token=token, + api_url=api_url, + max_retries=8, + min_delay_between_retries_millis=500, + timeout_secs=360, + ) + + apify_rqs_client = apify_client_async.request_queues() + + # Get or create the request queue + metadata = RequestQueueMetadata.model_validate( + await apify_rqs_client.get_or_create(name=id if id is not None else name), + ) + + apify_rq_client = apify_client_async.request_queue(request_queue_id=metadata.id) + + # Create the client instance + client = cls( + id=metadata.id, + name=metadata.name, + created_at=metadata.created_at, + accessed_at=metadata.accessed_at, + modified_at=metadata.modified_at, + had_multiple_clients=metadata.had_multiple_clients, + handled_request_count=metadata.handled_request_count, + pending_request_count=metadata.pending_request_count, + stats=metadata.stats, + total_request_count=metadata.total_request_count, + api_client=apify_rq_client, + ) + + # Cache the client by name + cls._cache_by_name[name] = client + + return client + + @override + async def drop(self) -> None: + async with self._lock: + await self._api_client.delete() + + # Remove the client from the cache + if self.metadata.name in self.__class__._cache_by_name: # noqa: SLF001 + del self.__class__._cache_by_name[self.metadata.name] # noqa: SLF001 + + @override + async def add_batch_of_requests( + self, + requests: Sequence[Request], + *, + forefront: bool = False, + ) -> AddRequestsResponse: + """Add a batch of requests to the queue. + + Args: + requests: The requests to add. + forefront: Whether to add the requests to the beginning of the queue. + + Returns: + Response containing information about the added requests. 
+ """ + # Prepare requests for API by converting to dictionaries + requests_dict = [request.model_dump(by_alias=True) for request in requests] + + # Remove 'id' fields from requests as the API doesn't accept them + for request_dict in requests_dict: + if 'id' in request_dict: + del request_dict['id'] + + # Send requests to API + response = await self._api_client.batch_add_requests(requests=requests_dict, forefront=forefront) + + # Update metadata after adding requests + await self._update_metadata() + + return AddRequestsResponse.model_validate(response) + + @override + async def get_request(self, request_id: str) -> Request | None: + """Get a request by ID. + + Args: + request_id: The ID of the request to get. + + Returns: + The request or None if not found. + """ + response = await self._api_client.get_request(request_id) + await self._update_metadata() + + if response is None: + return None + + return Request.model_validate(**response) + + @override + async def fetch_next_request(self) -> Request | None: + """Return the next request in the queue to be processed. + + Once you successfully finish processing of the request, you need to call `mark_request_as_handled` + to mark the request as handled in the queue. If there was some error in processing the request, call + `reclaim_request` instead, so that the queue will give the request to some other consumer + in another call to the `fetch_next_request` method. + + Returns: + The request or `None` if there are no more pending requests. + """ + # Ensure the queue head has requests if available + await self._ensure_head_is_non_empty() + + # If queue head is empty after ensuring, there are no requests + if not self._queue_head: + return None + + # Get the next request ID from the queue head + next_request_id = self._queue_head.popleft() + request = await self._get_or_hydrate_request(next_request_id) + + # Handle potential inconsistency where request might not be in the main table yet + if request is None: + logger.debug( + 'Cannot find a request from the beginning of queue, will be retried later', + extra={'nextRequestId': next_request_id}, + ) + return None + + # If the request was already handled, skip it + if request.handled_at is not None: + logger.debug( + 'Request fetched from the beginning of queue was already handled', + extra={'nextRequestId': next_request_id}, + ) + return None + + return request + + @override + async def mark_request_as_handled(self, request: Request) -> ProcessedRequest | None: + """Mark a request as handled after successful processing. + + Handled requests will never again be returned by the `fetch_next_request` method. + + Args: + request: The request to mark as handled. + + Returns: + Information about the queue operation. `None` if the given request was not in progress. 
+ """ + # Set the handled_at timestamp if not already set + if request.handled_at is None: + request.handled_at = datetime.now(tz=timezone.utc) + + try: + # Update the request in the API + processed_request = await self._update_request(request) + processed_request.unique_key = request.unique_key + + # Update the cache with the handled request + cache_key = unique_key_to_request_id(request.unique_key) + self._cache_request( + cache_key, + processed_request, + forefront=False, + hydrated_request=request, + ) + + # Update metadata after marking request as handled + await self._update_metadata() + except Exception as exc: + logger.debug(f'Error marking request {request.id} as handled: {exc!s}') + return None + else: + return processed_request + + @override + async def reclaim_request( + self, + request: Request, + *, + forefront: bool = False, + ) -> ProcessedRequest | None: + """Reclaim a failed request back to the queue. + + The request will be returned for processing later again by another call to `fetch_next_request`. + + Args: + request: The request to return to the queue. + forefront: Whether to add the request to the head or the end of the queue. + + Returns: + Information about the queue operation. `None` if the given request was not in progress. + """ + try: + # Update the request in the API + processed_request = await self._update_request(request, forefront=forefront) + processed_request.unique_key = request.unique_key + + # Update the cache + cache_key = unique_key_to_request_id(request.unique_key) + self._cache_request( + cache_key, + processed_request, + forefront=forefront, + hydrated_request=request, + ) + + # If we're adding to the forefront, we need to check for forefront requests + # in the next list_head call + if forefront: + self._should_check_for_forefront_requests = True + + # Try to release the lock on the request + try: + await self._delete_request_lock(request.id, forefront=forefront) + except Exception as err: + logger.debug(f'Failed to delete request lock for request {request.id}', exc_info=err) + + # Update metadata after reclaiming request + await self._update_metadata() + except Exception as exc: + logger.debug(f'Error reclaiming request {request.id}: {exc!s}') + return None + else: + return processed_request + + @override + async def is_empty(self) -> bool: + """Check if the queue is empty. + + Returns: + True if the queue is empty, False otherwise. + """ + head = await self._list_head(limit=1, lock_time=None) + return len(head.items) == 0 + + async def _ensure_head_is_non_empty(self) -> None: + """Ensure that the queue head has requests if they are available in the queue.""" + # If queue head has adequate requests, skip fetching more + if len(self._queue_head) > 1 and not self._should_check_for_forefront_requests: + return + + # Fetch requests from the API and populate the queue head + await self._list_head(lock_time=self._DEFAULT_LOCK_TIME) + + async def _get_or_hydrate_request(self, request_id: str) -> Request | None: + """Get a request by ID, either from cache or by fetching from API. + + Args: + request_id: The ID of the request to get. + + Returns: + The request if found and valid, otherwise None. 
+ """ + # First check if the request is in our cache + cached_entry = self._requests_cache.get(request_id) + + if cached_entry and cached_entry.hydrated: + # If we have the request hydrated in cache, check if lock is expired + if cached_entry.lock_expires_at and cached_entry.lock_expires_at < datetime.now(tz=timezone.utc): + # Try to prolong the lock if it's expired + try: + lock_secs = int(self._DEFAULT_LOCK_TIME.total_seconds()) + response = await self._prolong_request_lock( + request_id, forefront=cached_entry.forefront, lock_secs=lock_secs + ) + cached_entry.lock_expires_at = response.lock_expires_at + except Exception: + # If prolonging the lock fails, we lost the request + logger.debug(f'Failed to prolong lock for request {request_id}, returning None') + return None + + return cached_entry.hydrated + + # If not in cache or not hydrated, fetch the request + try: + # Try to acquire or prolong the lock + lock_secs = int(self._DEFAULT_LOCK_TIME.total_seconds()) + await self._prolong_request_lock(request_id, forefront=False, lock_secs=lock_secs) + + # Fetch the request data + request = await self.get_request(request_id) + + # If request is not found, release lock and return None + if not request: + await self._delete_request_lock(request_id) + return None + + # Update cache with hydrated request + cache_key = unique_key_to_request_id(request.unique_key) + self._cache_request( + cache_key, + ProcessedRequest( + id=request_id, + unique_key=request.unique_key, + was_already_present=True, + was_already_handled=request.handled_at is not None, + ), + forefront=False, + hydrated_request=request, + ) + except Exception as exc: + logger.debug(f'Error fetching or locking request {request_id}: {exc!s}') + return None + else: + return request + + async def _update_request( + self, + request: Request, + *, + forefront: bool = False, + ) -> ProcessedRequest: + """Update a request in the queue. + + Args: + request: The updated request. + forefront: Whether to put the updated request in the beginning or the end of the queue. + + Returns: + The updated request + """ + response = await self._api_client.update_request( + request=request.model_dump(by_alias=True), + forefront=forefront, + ) + + return ProcessedRequest.model_validate( + {'id': request.id, 'uniqueKey': request.unique_key} | response, + ) + + async def _list_head( + self, + *, + lock_time: timedelta | None = None, + limit: int = 25, + ) -> RequestQueueHead: + """Retrieve requests from the beginning of the queue. + + Args: + lock_time: Duration for which to lock the retrieved requests. + If None, requests will not be locked. + limit: Maximum number of requests to retrieve. + + Returns: + A collection of requests from the beginning of the queue. 
+ """ + # Return from cache if available and we're not checking for new forefront requests + if self._queue_head and not self._should_check_for_forefront_requests: + logger.debug(f'Using cached queue head with {len(self._queue_head)} requests') + + # Create a list of requests from the cached queue head + items = [] + for request_id in list(self._queue_head)[:limit]: + cached_request = self._requests_cache.get(request_id) + if cached_request and cached_request.hydrated: + items.append(cached_request.hydrated) + + return RequestQueueHead( + limit=limit, + had_multiple_clients=self._metadata.had_multiple_clients, + queue_modified_at=self._metadata.modified_at, + items=items, + queue_has_locked_requests=self._queue_has_locked_requests, + lock_time=lock_time, + ) + + # Otherwise fetch from API + lock_time = lock_time or self._DEFAULT_LOCK_TIME + lock_secs = int(lock_time.total_seconds()) + + response = await self._api_client.list_and_lock_head( + lock_secs=lock_secs, + limit=limit, + ) + + # Update the queue head cache + self._queue_has_locked_requests = response.get('queueHasLockedRequests', False) + + # Clear current queue head if we're checking for forefront requests + if self._should_check_for_forefront_requests: + self._queue_head.clear() + self._should_check_for_forefront_requests = False + + # Process and cache the requests + head_id_buffer = list[str]() + forefront_head_id_buffer = list[str]() + + for request_data in response.get('items', []): + request = Request.model_validate(request_data) + + # Skip requests without ID or unique key + if not request.id or not request.unique_key: + logger.debug( + 'Skipping request from queue head, missing ID or unique key', + extra={ + 'id': request.id, + 'unique_key': request.unique_key, + }, + ) + continue + + # Check if this request was already cached and if it was added to forefront + cache_key = unique_key_to_request_id(request.unique_key) + cached_request = self._requests_cache.get(cache_key) + forefront = cached_request.forefront if cached_request else False + + # Add to appropriate buffer based on forefront flag + if forefront: + forefront_head_id_buffer.insert(0, request.id) + else: + head_id_buffer.append(request.id) + + # Cache the request + self._cache_request( + cache_key, + ProcessedRequest( + id=request.id, + unique_key=request.unique_key, + was_already_present=True, + was_already_handled=False, + ), + forefront=forefront, + hydrated_request=request, + ) + + # Update the queue head deque + for request_id in head_id_buffer: + self._queue_head.append(request_id) + + for request_id in forefront_head_id_buffer: + self._queue_head.appendleft(request_id) + + return RequestQueueHead.model_validate(response) + + async def _prolong_request_lock( + self, + request_id: str, + *, + forefront: bool = False, + lock_secs: int, + ) -> ProlongRequestLockResponse: + """Prolong the lock on a specific request in the queue. + + Args: + request_id: The identifier of the request whose lock is to be prolonged. + forefront: Whether to put the request in the beginning or the end of the queue after lock expires. + lock_secs: The additional amount of time, in seconds, that the request will remain locked. + + Returns: + A response containing the time at which the lock will expire. 
+ """ + response = await self._api_client.prolong_request_lock( + request_id=request_id, + forefront=forefront, + lock_secs=lock_secs, + ) + + result = ProlongRequestLockResponse( + lock_expires_at=datetime.fromisoformat(response['lockExpiresAt'].replace('Z', '+00:00')) + ) + + # Update the cache with the new lock expiration + for cached_request in self._requests_cache.values(): + if cached_request.id == request_id: + cached_request.lock_expires_at = result.lock_expires_at + break + + return result + + async def _delete_request_lock( + self, + request_id: str, + *, + forefront: bool = False, + ) -> None: + """Delete the lock on a specific request in the queue. + + Args: + request_id: ID of the request to delete the lock. + forefront: Whether to put the request in the beginning or the end of the queue after the lock is deleted. + """ + try: + await self._api_client.delete_request_lock( + request_id=request_id, + forefront=forefront, + ) + + # Update the cache to remove the lock + for cached_request in self._requests_cache.values(): + if cached_request.id == request_id: + cached_request.lock_expires_at = None + break + except Exception as err: + logger.debug(f'Failed to delete request lock for request {request_id}', exc_info=err) + + def _cache_request( + self, + cache_key: str, + processed_request: ProcessedRequest, + *, + forefront: bool, + hydrated_request: Request | None = None, + ) -> None: + """Cache a request for future use. + + Args: + cache_key: The key to use for caching the request. + processed_request: The processed request information. + forefront: Whether the request was added to the forefront of the queue. + hydrated_request: The hydrated request object, if available. + """ + self._requests_cache[cache_key] = CachedRequest( + id=processed_request.id, + was_already_handled=processed_request.was_already_handled, + hydrated=hydrated_request, + lock_expires_at=None, + forefront=forefront, + ) + + async def _update_metadata(self) -> None: + """Update the request queue metadata with current information.""" + metadata = await self._api_client.get() + self._metadata = RequestQueueMetadata.model_validate(metadata) diff --git a/src/crawlee/storage_clients/_apify/_storage_client.py b/src/crawlee/storage_clients/_apify/_storage_client.py new file mode 100644 index 0000000000..1d4d66dd6a --- /dev/null +++ b/src/crawlee/storage_clients/_apify/_storage_client.py @@ -0,0 +1,65 @@ +from __future__ import annotations + +from typing_extensions import override + +from crawlee.configuration import Configuration +from crawlee.storage_clients._base import StorageClient + +from ._dataset_client import ApifyDatasetClient +from ._key_value_store_client import ApifyKeyValueStoreClient +from ._request_queue_client import ApifyRequestQueueClient + + +class ApifyStorageClient(StorageClient): + """Apify storage client.""" + + @override + async def open_dataset_client( + self, + *, + id: str | None = None, + name: str | None = None, + configuration: Configuration | None = None, + ) -> ApifyDatasetClient: + configuration = configuration or Configuration.get_global_configuration() + client = await ApifyDatasetClient.open(id=id, name=name, configuration=configuration) + + if configuration.purge_on_start: + await client.drop() + client = await ApifyDatasetClient.open(id=id, name=name, configuration=configuration) + + return client + + @override + async def open_key_value_store_client( + self, + *, + id: str | None = None, + name: str | None = None, + configuration: Configuration | None = None, + ) -> 
ApifyKeyValueStoreClient: + configuration = configuration or Configuration.get_global_configuration() + client = await ApifyKeyValueStoreClient.open(id=id, name=name, configuration=configuration) + + if configuration.purge_on_start: + await client.drop() + client = await ApifyKeyValueStoreClient.open(id=id, name=name, configuration=configuration) + + return client + + @override + async def open_request_queue_client( + self, + *, + id: str | None = None, + name: str | None = None, + configuration: Configuration | None = None, + ) -> ApifyRequestQueueClient: + configuration = configuration or Configuration.get_global_configuration() + client = await ApifyRequestQueueClient.open(id=id, name=name, configuration=configuration) + + if configuration.purge_on_start: + await client.drop() + client = await ApifyRequestQueueClient.open(id=id, name=name, configuration=configuration) + + return client diff --git a/src/crawlee/storage_clients/_apify/py.typed b/src/crawlee/storage_clients/_apify/py.typed new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/crawlee/storage_clients/_base/__init__.py b/src/crawlee/storage_clients/_base/__init__.py index 5194da8768..73298560da 100644 --- a/src/crawlee/storage_clients/_base/__init__.py +++ b/src/crawlee/storage_clients/_base/__init__.py @@ -1,20 +1,11 @@ from ._dataset_client import DatasetClient -from ._dataset_collection_client import DatasetCollectionClient from ._key_value_store_client import KeyValueStoreClient -from ._key_value_store_collection_client import KeyValueStoreCollectionClient from ._request_queue_client import RequestQueueClient -from ._request_queue_collection_client import RequestQueueCollectionClient from ._storage_client import StorageClient -from ._types import ResourceClient, ResourceCollectionClient __all__ = [ 'DatasetClient', - 'DatasetCollectionClient', 'KeyValueStoreClient', - 'KeyValueStoreCollectionClient', 'RequestQueueClient', - 'RequestQueueCollectionClient', - 'ResourceClient', - 'ResourceCollectionClient', 'StorageClient', ] diff --git a/src/crawlee/storage_clients/_base/_dataset_client.py b/src/crawlee/storage_clients/_base/_dataset_client.py index d8495b2dd0..c73eb6f51f 100644 --- a/src/crawlee/storage_clients/_base/_dataset_client.py +++ b/src/crawlee/storage_clients/_base/_dataset_client.py @@ -7,58 +7,76 @@ if TYPE_CHECKING: from collections.abc import AsyncIterator - from contextlib import AbstractAsyncContextManager + from typing import Any - from httpx import Response - - from crawlee._types import JsonSerializable + from crawlee.configuration import Configuration from crawlee.storage_clients.models import DatasetItemsListPage, DatasetMetadata @docs_group('Abstract classes') class DatasetClient(ABC): - """An abstract class for dataset resource clients. + """An abstract class for dataset storage clients. - These clients are specific to the type of resource they manage and operate under a designated storage - client, like a memory storage client. - """ + Dataset clients provide an interface for accessing and manipulating dataset storage. They handle + operations like adding and getting dataset items across different storage backends. 
- _LIST_ITEMS_LIMIT = 999_999_999_999 - """This is what API returns in the x-apify-pagination-limit header when no limit query parameter is used.""" + Storage clients are specific to the type of storage they manage (`Dataset`, `KeyValueStore`, + `RequestQueue`), and can operate with various storage systems including memory, file system, + databases, and cloud storage solutions. - @abstractmethod - async def get(self) -> DatasetMetadata | None: - """Get metadata about the dataset being managed by this client. + This abstract class defines the interface that all specific dataset clients must implement. + """ - Returns: - An object containing the dataset's details, or None if the dataset does not exist. - """ + @property + @abstractmethod + def metadata(self) -> DatasetMetadata: + """The metadata of the dataset.""" + @classmethod @abstractmethod - async def update( - self, + async def open( + cls, *, - name: str | None = None, - ) -> DatasetMetadata: - """Update the dataset metadata. + id: str | None, + name: str | None, + configuration: Configuration, + ) -> DatasetClient: + """Open existing or create a new dataset client. + + If a dataset with the given name or ID already exists, the appropriate dataset client is returned. + Otherwise, a new dataset is created and client for it is returned. + + The backend method for the `Dataset.open` call. Args: - name: New new name for the dataset. + id: The ID of the dataset. If not provided, an ID may be generated. + name: The name of the dataset. If not provided a default name may be used. + configuration: The configuration object. Returns: - An object reflecting the updated dataset metadata. + A dataset client instance. + """ + + @abstractmethod + async def drop(self) -> None: + """Drop the whole dataset and remove all its items. + + The backend method for the `Dataset.drop` call. """ @abstractmethod - async def delete(self) -> None: - """Permanently delete the dataset managed by this client.""" + async def push_data(self, data: list[Any] | dict[str, Any]) -> None: + """Push data to the dataset. + + The backend method for the `Dataset.push_data` call. + """ @abstractmethod - async def list_items( + async def get_data( self, *, - offset: int | None = 0, - limit: int | None = _LIST_ITEMS_LIMIT, + offset: int = 0, + limit: int | None = 999_999_999_999, clean: bool = False, desc: bool = False, fields: list[str] | None = None, @@ -69,27 +87,9 @@ async def list_items( flatten: list[str] | None = None, view: str | None = None, ) -> DatasetItemsListPage: - """Retrieve a paginated list of items from a dataset based on various filtering parameters. - - This method provides the flexibility to filter, sort, and modify the appearance of dataset items - when listed. Each parameter modifies the result set according to its purpose. The method also - supports pagination through 'offset' and 'limit' parameters. - - Args: - offset: The number of initial items to skip. - limit: The maximum number of items to return. - clean: If True, removes empty items and hidden fields, equivalent to 'skip_hidden' and 'skip_empty'. - desc: If True, items are returned in descending order, i.e., newest first. - fields: Specifies a subset of fields to include in each item. - omit: Specifies a subset of fields to exclude from each item. - unwind: Specifies a field that should be unwound. If it's an array, each element becomes a separate record. - skip_empty: If True, omits items that are empty after other filters have been applied. 
- skip_hidden: If True, omits fields starting with the '#' character. - flatten: A list of fields to flatten in each item. - view: The specific view of the dataset to use when retrieving items. + """Get data from the dataset with various filtering options. - Returns: - An object with filtered, sorted, and paginated dataset items plus pagination details. + The backend method for the `Dataset.get_data` call. """ @abstractmethod @@ -106,126 +106,12 @@ async def iterate_items( skip_empty: bool = False, skip_hidden: bool = False, ) -> AsyncIterator[dict]: - """Iterate over items in the dataset according to specified filters and sorting. - - This method allows for asynchronously iterating through dataset items while applying various filters such as - skipping empty items, hiding specific fields, and sorting. It supports pagination via `offset` and `limit` - parameters, and can modify the appearance of dataset items using `fields`, `omit`, `unwind`, `skip_empty`, and - `skip_hidden` parameters. + """Iterate over the dataset items with filtering options. - Args: - offset: The number of initial items to skip. - limit: The maximum number of items to iterate over. None means no limit. - clean: If True, removes empty items and hidden fields, equivalent to 'skip_hidden' and 'skip_empty'. - desc: If set to True, items are returned in descending order, i.e., newest first. - fields: Specifies a subset of fields to include in each item. - omit: Specifies a subset of fields to exclude from each item. - unwind: Specifies a field that should be unwound into separate items. - skip_empty: If set to True, omits items that are empty after other filters have been applied. - skip_hidden: If set to True, omits fields starting with the '#' character from the output. - - Yields: - An asynchronous iterator of dictionary objects, each representing a dataset item after applying - the specified filters and transformations. + The backend method for the `Dataset.iterate_items` call. """ # This syntax is to make mypy properly work with abstract AsyncIterator. # https://mypy.readthedocs.io/en/stable/more_types.html#asynchronous-iterators raise NotImplementedError if False: # type: ignore[unreachable] yield 0 - - @abstractmethod - async def get_items_as_bytes( - self, - *, - item_format: str = 'json', - offset: int | None = None, - limit: int | None = None, - desc: bool = False, - clean: bool = False, - bom: bool = False, - delimiter: str | None = None, - fields: list[str] | None = None, - omit: list[str] | None = None, - unwind: str | None = None, - skip_empty: bool = False, - skip_header_row: bool = False, - skip_hidden: bool = False, - xml_root: str | None = None, - xml_row: str | None = None, - flatten: list[str] | None = None, - ) -> bytes: - """Retrieve dataset items as bytes. - - Args: - item_format: Output format (e.g., 'json', 'csv'); default is 'json'. - offset: Number of items to skip; default is 0. - limit: Max number of items to return; no default limit. - desc: If True, results are returned in descending order. - clean: If True, filters out empty items and hidden fields. - bom: Include or exclude UTF-8 BOM; default behavior varies by format. - delimiter: Delimiter character for CSV; default is ','. - fields: List of fields to include in the results. - omit: List of fields to omit from the results. - unwind: Unwinds a field into separate records. - skip_empty: If True, skips empty items in the output. - skip_header_row: If True, skips the header row in CSV. 
- skip_hidden: If True, skips hidden fields in the output. - xml_root: Root element name for XML output; default is 'items'. - xml_row: Element name for each item in XML output; default is 'item'. - flatten: List of fields to flatten. - - Returns: - The dataset items as raw bytes. - """ - - @abstractmethod - async def stream_items( - self, - *, - item_format: str = 'json', - offset: int | None = None, - limit: int | None = None, - desc: bool = False, - clean: bool = False, - bom: bool = False, - delimiter: str | None = None, - fields: list[str] | None = None, - omit: list[str] | None = None, - unwind: str | None = None, - skip_empty: bool = False, - skip_header_row: bool = False, - skip_hidden: bool = False, - xml_root: str | None = None, - xml_row: str | None = None, - ) -> AbstractAsyncContextManager[Response | None]: - """Retrieve dataset items as a streaming response. - - Args: - item_format: Output format, options include json, jsonl, csv, html, xlsx, xml, rss; default is json. - offset: Number of items to skip at the start; default is 0. - limit: Maximum number of items to return; no default limit. - desc: If True, reverses the order of results. - clean: If True, filters out empty items and hidden fields. - bom: Include or exclude UTF-8 BOM; varies by format. - delimiter: Delimiter for CSV files; default is ','. - fields: List of fields to include in the output. - omit: List of fields to omit from the output. - unwind: Unwinds a field into separate records. - skip_empty: If True, empty items are omitted. - skip_header_row: If True, skips the header row in CSV. - skip_hidden: If True, hides fields starting with the # character. - xml_root: Custom root element name for XML output; default is 'items'. - xml_row: Custom element name for each item in XML; default is 'item'. - - Yields: - The dataset items in a streaming response. - """ - - @abstractmethod - async def push_items(self, items: JsonSerializable) -> None: - """Push items to the dataset. - - Args: - items: The items which to push in the dataset. They must be JSON serializable. - """ diff --git a/src/crawlee/storage_clients/_base/_dataset_collection_client.py b/src/crawlee/storage_clients/_base/_dataset_collection_client.py deleted file mode 100644 index 8530655c8c..0000000000 --- a/src/crawlee/storage_clients/_base/_dataset_collection_client.py +++ /dev/null @@ -1,59 +0,0 @@ -from __future__ import annotations - -from abc import ABC, abstractmethod -from typing import TYPE_CHECKING - -from crawlee._utils.docs import docs_group - -if TYPE_CHECKING: - from crawlee.storage_clients.models import DatasetListPage, DatasetMetadata - - -@docs_group('Abstract classes') -class DatasetCollectionClient(ABC): - """An abstract class for dataset collection clients. - - This collection client handles operations that involve multiple instances of a given resource type. - """ - - @abstractmethod - async def get_or_create( - self, - *, - id: str | None = None, - name: str | None = None, - schema: dict | None = None, - ) -> DatasetMetadata: - """Retrieve an existing dataset by its name or ID, or create a new one if it does not exist. - - Args: - id: Optional ID of the dataset to retrieve or create. If provided, the method will attempt - to find a dataset with the ID. - name: Optional name of the dataset resource to retrieve or create. If provided, the method will - attempt to find a dataset with this name. - schema: Optional schema for the dataset resource to be created. 
- - Returns: - Metadata object containing the information of the retrieved or created dataset. - """ - - @abstractmethod - async def list( - self, - *, - unnamed: bool = False, - limit: int | None = None, - offset: int | None = None, - desc: bool = False, - ) -> DatasetListPage: - """List the available datasets. - - Args: - unnamed: Whether to list only the unnamed datasets. - limit: Maximum number of datasets to return. - offset: Number of datasets to skip from the beginning of the list. - desc: Whether to sort the datasets in descending order. - - Returns: - The list of available datasets matching the specified filters. - """ diff --git a/src/crawlee/storage_clients/_base/_key_value_store_client.py b/src/crawlee/storage_clients/_base/_key_value_store_client.py index 6a5d141be6..957b53db0e 100644 --- a/src/crawlee/storage_clients/_base/_key_value_store_client.py +++ b/src/crawlee/storage_clients/_base/_key_value_store_client.py @@ -6,126 +6,105 @@ from crawlee._utils.docs import docs_group if TYPE_CHECKING: - from contextlib import AbstractAsyncContextManager + from collections.abc import AsyncIterator - from httpx import Response - - from crawlee.storage_clients.models import KeyValueStoreListKeysPage, KeyValueStoreMetadata, KeyValueStoreRecord + from crawlee.configuration import Configuration + from crawlee.storage_clients.models import KeyValueStoreMetadata, KeyValueStoreRecord, KeyValueStoreRecordMetadata @docs_group('Abstract classes') class KeyValueStoreClient(ABC): - """An abstract class for key-value store resource clients. + """An abstract class for key-value store (KVS) storage clients. + + Key-value stores clients provide an interface for accessing and manipulating KVS storage. They handle + operations like getting, setting, deleting KVS values across different storage backends. + + Storage clients are specific to the type of storage they manage (`Dataset`, `KeyValueStore`, + `RequestQueue`), and can operate with various storage systems including memory, file system, + databases, and cloud storage solutions. - These clients are specific to the type of resource they manage and operate under a designated storage - client, like a memory storage client. + This abstract class defines the interface that all specific KVS clients must implement. """ + @property @abstractmethod - async def get(self) -> KeyValueStoreMetadata | None: - """Get metadata about the key-value store being managed by this client. - - Returns: - An object containing the key-value store's details, or None if the key-value store does not exist. - """ + def metadata(self) -> KeyValueStoreMetadata: + """The metadata of the key-value store.""" + @classmethod @abstractmethod - async def update( - self, + async def open( + cls, *, - name: str | None = None, - ) -> KeyValueStoreMetadata: - """Update the key-value store metadata. + id: str | None, + name: str | None, + configuration: Configuration, + ) -> KeyValueStoreClient: + """Open existing or create a new key-value store client. - Args: - name: New new name for the key-value store. + If a key-value store with the given name or ID already exists, the appropriate + key-value store client is returned. Otherwise, a new key-value store is created + and a client for it is returned. - Returns: - An object reflecting the updated key-value store metadata. 
- """ - - @abstractmethod - async def delete(self) -> None: - """Permanently delete the key-value store managed by this client.""" - - @abstractmethod - async def list_keys( - self, - *, - limit: int = 1000, - exclusive_start_key: str | None = None, - ) -> KeyValueStoreListKeysPage: - """List the keys in the key-value store. + The backend method for the `KeyValueStoreClient.open` call. Args: - limit: Number of keys to be returned. Maximum value is 1000. - exclusive_start_key: All keys up to this one (including) are skipped from the result. + id: The ID of the key-value store. If not provided, an ID may be generated. + name: The name of the key-value store. If not provided a default name may be used. + configuration: The configuration object. Returns: - The list of keys in the key-value store matching the given arguments. + A key-value store client instance. """ @abstractmethod - async def get_record(self, key: str) -> KeyValueStoreRecord | None: - """Retrieve the given record from the key-value store. + async def drop(self) -> None: + """Drop the whole key-value store and remove all its values. - Args: - key: Key of the record to retrieve. - - Returns: - The requested record, or None, if the record does not exist + The backend method for the `KeyValueStore.drop` call. """ @abstractmethod - async def get_record_as_bytes(self, key: str) -> KeyValueStoreRecord[bytes] | None: - """Retrieve the given record from the key-value store, without parsing it. - - Args: - key: Key of the record to retrieve. + async def get_value(self, *, key: str) -> KeyValueStoreRecord | None: + """Retrieve the given record from the key-value store. - Returns: - The requested record, or None, if the record does not exist + The backend method for the `KeyValueStore.get_value` call. """ @abstractmethod - async def stream_record(self, key: str) -> AbstractAsyncContextManager[KeyValueStoreRecord[Response] | None]: - """Retrieve the given record from the key-value store, as a stream. - - Args: - key: Key of the record to retrieve. + async def set_value(self, *, key: str, value: Any, content_type: str | None = None) -> None: + """Set a value in the key-value store by its key. - Returns: - The requested record as a context-managed streaming Response, or None, if the record does not exist + The backend method for the `KeyValueStore.set_value` call. """ @abstractmethod - async def set_record(self, key: str, value: Any, content_type: str | None = None) -> None: - """Set a value to the given record in the key-value store. + async def delete_value(self, *, key: str) -> None: + """Delete a value from the key-value store by its key. - Args: - key: The key of the record to save the value to. - value: The value to save into the record. - content_type: The content type of the saved value. + The backend method for the `KeyValueStore.delete_value` call. """ @abstractmethod - async def delete_record(self, key: str) -> None: - """Delete the specified record from the key-value store. + async def iterate_keys( + self, + *, + exclusive_start_key: str | None = None, + limit: int | None = None, + ) -> AsyncIterator[KeyValueStoreRecordMetadata]: + """Iterate over all the existing keys in the key-value store. - Args: - key: The key of the record which to delete. + The backend method for the `KeyValueStore.iterate_keys` call. """ + # This syntax is to make mypy properly work with abstract AsyncIterator. 
+ # https://mypy.readthedocs.io/en/stable/more_types.html#asynchronous-iterators + raise NotImplementedError + if False: # type: ignore[unreachable] + yield 0 @abstractmethod - async def get_public_url(self, key: str) -> str: + async def get_public_url(self, *, key: str) -> str: """Get the public URL for the given key. - Args: - key: Key of the record for which URL is required. - - Returns: - The public URL for the given key. - - Raises: - ValueError: If the key does not exist. + The backend method for the `KeyValueStore.get_public_url` call. """ diff --git a/src/crawlee/storage_clients/_base/_key_value_store_collection_client.py b/src/crawlee/storage_clients/_base/_key_value_store_collection_client.py deleted file mode 100644 index b447cf49b1..0000000000 --- a/src/crawlee/storage_clients/_base/_key_value_store_collection_client.py +++ /dev/null @@ -1,59 +0,0 @@ -from __future__ import annotations - -from abc import ABC, abstractmethod -from typing import TYPE_CHECKING - -from crawlee._utils.docs import docs_group - -if TYPE_CHECKING: - from crawlee.storage_clients.models import KeyValueStoreListPage, KeyValueStoreMetadata - - -@docs_group('Abstract classes') -class KeyValueStoreCollectionClient(ABC): - """An abstract class for key-value store collection clients. - - This collection client handles operations that involve multiple instances of a given resource type. - """ - - @abstractmethod - async def get_or_create( - self, - *, - id: str | None = None, - name: str | None = None, - schema: dict | None = None, - ) -> KeyValueStoreMetadata: - """Retrieve an existing key-value store by its name or ID, or create a new one if it does not exist. - - Args: - id: Optional ID of the key-value store to retrieve or create. If provided, the method will attempt - to find a key-value store with the ID. - name: Optional name of the key-value store resource to retrieve or create. If provided, the method will - attempt to find a key-value store with this name. - schema: Optional schema for the key-value store resource to be created. - - Returns: - Metadata object containing the information of the retrieved or created key-value store. - """ - - @abstractmethod - async def list( - self, - *, - unnamed: bool = False, - limit: int | None = None, - offset: int | None = None, - desc: bool = False, - ) -> KeyValueStoreListPage: - """List the available key-value stores. - - Args: - unnamed: Whether to list only the unnamed key-value stores. - limit: Maximum number of key-value stores to return. - offset: Number of key-value stores to skip from the beginning of the list. - desc: Whether to sort the key-value stores in descending order. - - Returns: - The list of available key-value stores matching the specified filters. - """ diff --git a/src/crawlee/storage_clients/_base/_request_queue_client.py b/src/crawlee/storage_clients/_base/_request_queue_client.py index 06b180801a..7f2cdc11f1 100644 --- a/src/crawlee/storage_clients/_base/_request_queue_client.py +++ b/src/crawlee/storage_clients/_base/_request_queue_client.py @@ -8,13 +8,11 @@ if TYPE_CHECKING: from collections.abc import Sequence + from crawlee.configuration import Configuration from crawlee.storage_clients.models import ( - BatchRequestsOperationResponse, + AddRequestsResponse, ProcessedRequest, - ProlongRequestLockResponse, Request, - RequestQueueHead, - RequestQueueHeadWithLocks, RequestQueueMetadata, ) @@ -27,91 +25,63 @@ class RequestQueueClient(ABC): client, like a memory storage client. 
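Before moving on to the request queue client below, here is a rough usage sketch of the key-value store contract rewritten above. The import path is the private file-system module added later in this patch, so treat it as an assumption about the final public API.

```
from crawlee.configuration import Configuration
from crawlee.storage_clients._file_system import FileSystemKeyValueStoreClient


async def kvs_roundtrip() -> None:
    # Open (or create) a named key-value store via the new classmethod contract.
    kvs_client = await FileSystemKeyValueStoreClient.open(
        id=None,
        name='example-store',
        configuration=Configuration.get_global_configuration(),
    )

    # Set, get, iterate and delete values using the keyword-only interface.
    await kvs_client.set_value(key='greeting', value={'hello': 'world'})

    record = await kvs_client.get_value(key='greeting')
    print(record.value if record else None)

    async for key_metadata in kvs_client.iterate_keys():
        print(key_metadata.key, key_metadata.content_type)

    await kvs_client.delete_value(key='greeting')
```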
""" + @property @abstractmethod - async def get(self) -> RequestQueueMetadata | None: - """Get metadata about the request queue being managed by this client. - - Returns: - An object containing the request queue's details, or None if the request queue does not exist. - """ + def metadata(self) -> RequestQueueMetadata: + """The metadata of the request queue.""" + @classmethod @abstractmethod - async def update( - self, + async def open( + cls, *, - name: str | None = None, - ) -> RequestQueueMetadata: - """Update the request queue metadata. - - Args: - name: New new name for the request queue. - - Returns: - An object reflecting the updated request queue metadata. - """ - - @abstractmethod - async def delete(self) -> None: - """Permanently delete the request queue managed by this client.""" - - @abstractmethod - async def list_head(self, *, limit: int | None = None) -> RequestQueueHead: - """Retrieve a given number of requests from the beginning of the queue. + id: str | None, + name: str | None, + configuration: Configuration, + ) -> RequestQueueClient: + """Open a request queue client. Args: - limit: How many requests to retrieve. + id: ID of the queue to open. If not provided, a new queue will be created with a random ID. + name: Name of the queue to open. If not provided, the queue will be unnamed. + configuration: The configuration object. Returns: - The desired number of requests from the beginning of the queue. + A request queue client. """ @abstractmethod - async def list_and_lock_head(self, *, lock_secs: int, limit: int | None = None) -> RequestQueueHeadWithLocks: - """Fetch and lock a specified number of requests from the start of the queue. - - Retrieve and locks the first few requests of a queue for the specified duration. This prevents the requests - from being fetched by another client until the lock expires. - - Args: - lock_secs: Duration for which the requests are locked, in seconds. - limit: Maximum number of requests to retrieve and lock. - - Returns: - The desired number of locked requests from the beginning of the queue. - """ - - @abstractmethod - async def add_request( - self, - request: Request, - *, - forefront: bool = False, - ) -> ProcessedRequest: - """Add a request to the queue. - - Args: - request: The request to add to the queue. - forefront: Whether to add the request to the head or the end of the queue. + async def drop(self) -> None: + """Drop the whole request queue and remove all its values. - Returns: - Request queue operation information. + The backend method for the `RequestQueue.drop` call. """ @abstractmethod - async def batch_add_requests( + async def add_batch_of_requests( self, requests: Sequence[Request], *, forefront: bool = False, - ) -> BatchRequestsOperationResponse: - """Add a batch of requests to the queue. + ) -> AddRequestsResponse: + """Add batch of requests to the queue. + + This method adds a batch of requests to the queue. Each request is processed based on its uniqueness + (determined by `unique_key`). Duplicates will be identified but not re-added to the queue. Args: - requests: The requests to add to the queue. - forefront: Whether to add the requests to the head or the end of the queue. + requests: The collection of requests to add to the queue. + forefront: Whether to put the added requests at the beginning (True) or the end (False) of the queue. + When True, the requests will be processed sooner than previously added requests. + batch_size: The maximum number of requests to add in a single batch. 
+ wait_time_between_batches: The time to wait between adding batches of requests. + wait_for_all_requests_to_be_added: If True, the method will wait until all requests are added + to the queue before returning. + wait_for_all_requests_to_be_added_timeout: The maximum time to wait for all requests to be added. Returns: - Request queue batch operation information. + A response object containing information about which requests were successfully + processed and which failed (if any). """ @abstractmethod @@ -126,64 +96,58 @@ async def get_request(self, request_id: str) -> Request | None: """ @abstractmethod - async def update_request( - self, - request: Request, - *, - forefront: bool = False, - ) -> ProcessedRequest: - """Update a request in the queue. + async def fetch_next_request(self) -> Request | None: + """Return the next request in the queue to be processed. - Args: - request: The updated request. - forefront: Whether to put the updated request in the beginning or the end of the queue. + Once you successfully finish processing of the request, you need to call `RequestQueue.mark_request_as_handled` + to mark the request as handled in the queue. If there was some error in processing the request, call + `RequestQueue.reclaim_request` instead, so that the queue will give the request to some other consumer + in another call to the `fetch_next_request` method. + + Note that the `None` return value does not mean the queue processing finished, it means there are currently + no pending requests. To check whether all requests in queue were finished, use `RequestQueue.is_finished` + instead. Returns: - The updated request + The request or `None` if there are no more pending requests. """ @abstractmethod - async def delete_request(self, request_id: str) -> None: - """Delete a request from the queue. - - Args: - request_id: ID of the request to delete. - """ + async def mark_request_as_handled(self, request: Request) -> ProcessedRequest | None: + """Mark a request as handled after successful processing. - @abstractmethod - async def prolong_request_lock( - self, - request_id: str, - *, - forefront: bool = False, - lock_secs: int, - ) -> ProlongRequestLockResponse: - """Prolong the lock on a specific request in the queue. + Handled requests will never again be returned by the `RequestQueue.fetch_next_request` method. Args: - request_id: The identifier of the request whose lock is to be prolonged. - forefront: Whether to put the request in the beginning or the end of the queue after lock expires. - lock_secs: The additional amount of time, in seconds, that the request will remain locked. + request: The request to mark as handled. + + Returns: + Information about the queue operation. `None` if the given request was not in progress. """ @abstractmethod - async def delete_request_lock( + async def reclaim_request( self, - request_id: str, + request: Request, *, forefront: bool = False, - ) -> None: - """Delete the lock on a specific request in the queue. + ) -> ProcessedRequest | None: + """Reclaim a failed request back to the queue. + + The request will be returned for processing later again by another call to `RequestQueue.fetch_next_request`. Args: - request_id: ID of the request to delete the lock. - forefront: Whether to put the request in the beginning or the end of the queue after the lock is deleted. + request: The request to return to the queue. + forefront: Whether to add the request to the head or the end of the queue. + + Returns: + Information about the queue operation. 
`None` if the given request was not in progress. """ @abstractmethod - async def batch_delete_requests(self, requests: list[Request]) -> BatchRequestsOperationResponse: - """Delete given requests from the queue. + async def is_empty(self) -> bool: + """Check if the request queue is empty. - Args: - requests: The requests to delete from the queue. + Returns: + True if the request queue is empty, False otherwise. """ diff --git a/src/crawlee/storage_clients/_base/_request_queue_collection_client.py b/src/crawlee/storage_clients/_base/_request_queue_collection_client.py deleted file mode 100644 index 7de876c344..0000000000 --- a/src/crawlee/storage_clients/_base/_request_queue_collection_client.py +++ /dev/null @@ -1,59 +0,0 @@ -from __future__ import annotations - -from abc import ABC, abstractmethod -from typing import TYPE_CHECKING - -from crawlee._utils.docs import docs_group - -if TYPE_CHECKING: - from crawlee.storage_clients.models import RequestQueueListPage, RequestQueueMetadata - - -@docs_group('Abstract classes') -class RequestQueueCollectionClient(ABC): - """An abstract class for request queue collection clients. - - This collection client handles operations that involve multiple instances of a given resource type. - """ - - @abstractmethod - async def get_or_create( - self, - *, - id: str | None = None, - name: str | None = None, - schema: dict | None = None, - ) -> RequestQueueMetadata: - """Retrieve an existing request queue by its name or ID, or create a new one if it does not exist. - - Args: - id: Optional ID of the request queue to retrieve or create. If provided, the method will attempt - to find a request queue with the ID. - name: Optional name of the request queue resource to retrieve or create. If provided, the method will - attempt to find a request queue with this name. - schema: Optional schema for the request queue resource to be created. - - Returns: - Metadata object containing the information of the retrieved or created request queue. - """ - - @abstractmethod - async def list( - self, - *, - unnamed: bool = False, - limit: int | None = None, - offset: int | None = None, - desc: bool = False, - ) -> RequestQueueListPage: - """List the available request queues. - - Args: - unnamed: Whether to list only the unnamed request queues. - limit: Maximum number of request queues to return. - offset: Number of request queues to skip from the beginning of the list. - desc: Whether to sort the request queues in descending order. - - Returns: - The list of available request queues matching the specified filters. 
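The fetch/mark/reclaim contract described above boils down to a consumer loop like the one sketched here. The loop itself is illustrative and not part of the patch; it only uses methods defined on the new `RequestQueueClient` base class.

```
from crawlee import Request
from crawlee.storage_clients._base import RequestQueueClient


async def drain_queue(rq_client: RequestQueueClient) -> None:
    # Enqueue a small batch, then fetch requests one by one and either
    # mark them handled or reclaim them for another attempt.
    await rq_client.add_batch_of_requests(
        [Request.from_url('https://crawlee.dev/'), Request.from_url('https://apify.com/')]
    )

    while not await rq_client.is_empty():
        request = await rq_client.fetch_next_request()
        if request is None:
            break  # Nothing is pending right now.

        try:
            print(f'Processing {request.url}')
        except Exception:
            # Processing failed - give the request back to the queue.
            await rq_client.reclaim_request(request)
        else:
            await rq_client.mark_request_as_handled(request)
```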
- """ diff --git a/src/crawlee/storage_clients/_base/_storage_client.py b/src/crawlee/storage_clients/_base/_storage_client.py index 4f022cf30a..fefa7ea5cb 100644 --- a/src/crawlee/storage_clients/_base/_storage_client.py +++ b/src/crawlee/storage_clients/_base/_storage_client.py @@ -1,62 +1,45 @@ -# Inspiration: https://github.com/apify/crawlee/blob/v3.8.2/packages/types/src/storages.ts#L314:L328 - from __future__ import annotations from abc import ABC, abstractmethod from typing import TYPE_CHECKING -from crawlee._utils.docs import docs_group - if TYPE_CHECKING: + from crawlee.configuration import Configuration + from ._dataset_client import DatasetClient - from ._dataset_collection_client import DatasetCollectionClient from ._key_value_store_client import KeyValueStoreClient - from ._key_value_store_collection_client import KeyValueStoreCollectionClient from ._request_queue_client import RequestQueueClient - from ._request_queue_collection_client import RequestQueueCollectionClient -@docs_group('Abstract classes') class StorageClient(ABC): - """Defines an abstract base for storage clients. - - It offers interfaces to get subclients for interacting with storage resources like datasets, key-value stores, - and request queues. - """ - - @abstractmethod - def dataset(self, id: str) -> DatasetClient: - """Get a subclient for a specific dataset by its ID.""" - - @abstractmethod - def datasets(self) -> DatasetCollectionClient: - """Get a subclient for dataset collection operations.""" - - @abstractmethod - def key_value_store(self, id: str) -> KeyValueStoreClient: - """Get a subclient for a specific key-value store by its ID.""" + """Base class for storage clients.""" @abstractmethod - def key_value_stores(self) -> KeyValueStoreCollectionClient: - """Get a subclient for key-value store collection operations.""" + async def open_dataset_client( + self, + *, + id: str | None = None, + name: str | None = None, + configuration: Configuration | None = None, + ) -> DatasetClient: + """Open a dataset client.""" @abstractmethod - def request_queue(self, id: str) -> RequestQueueClient: - """Get a subclient for a specific request queue by its ID.""" + async def open_key_value_store_client( + self, + *, + id: str | None = None, + name: str | None = None, + configuration: Configuration | None = None, + ) -> KeyValueStoreClient: + """Open a key-value store client.""" @abstractmethod - def request_queues(self) -> RequestQueueCollectionClient: - """Get a subclient for request queue collection operations.""" - - @abstractmethod - async def purge_on_start(self) -> None: - """Perform a purge of the default storages. - - This method ensures that the purge is executed only once during the lifetime of the instance. - It is primarily used to clean up residual data from previous runs to maintain a clean state. - If the storage client does not support purging, leave it empty. 
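With the dedicated `purge_on_start` hook removed, purging becomes the responsibility of the concrete `open_*` methods, as seen in `ApifyStorageClient` earlier in this diff. Below is a compressed restatement of that drop-and-reopen pattern as a hypothetical helper; the helper itself is not part of the patch.

```
from crawlee.configuration import Configuration
from crawlee.storage_clients._base import DatasetClient


async def open_with_purge(
    client_cls: type[DatasetClient],
    name: str,
    configuration: Configuration,
) -> DatasetClient:
    # Open the client, and if purge-on-start is enabled, drop the storage
    # and reopen it so the crawler starts from a clean state.
    client = await client_cls.open(id=None, name=name, configuration=configuration)

    if configuration.purge_on_start:
        await client.drop()
        client = await client_cls.open(id=None, name=name, configuration=configuration)

    return client
```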
- """ - - def get_rate_limit_errors(self) -> dict[int, int]: - """Return statistics about rate limit errors encountered by the HTTP client in storage client.""" - return {} + async def open_request_queue_client( + self, + *, + id: str | None = None, + name: str | None = None, + configuration: Configuration | None = None, + ) -> RequestQueueClient: + """Open a request queue client.""" diff --git a/src/crawlee/storage_clients/_base/_types.py b/src/crawlee/storage_clients/_base/_types.py deleted file mode 100644 index a5cf1325f5..0000000000 --- a/src/crawlee/storage_clients/_base/_types.py +++ /dev/null @@ -1,22 +0,0 @@ -from __future__ import annotations - -from typing import Union - -from ._dataset_client import DatasetClient -from ._dataset_collection_client import DatasetCollectionClient -from ._key_value_store_client import KeyValueStoreClient -from ._key_value_store_collection_client import KeyValueStoreCollectionClient -from ._request_queue_client import RequestQueueClient -from ._request_queue_collection_client import RequestQueueCollectionClient - -ResourceClient = Union[ - DatasetClient, - KeyValueStoreClient, - RequestQueueClient, -] - -ResourceCollectionClient = Union[ - DatasetCollectionClient, - KeyValueStoreCollectionClient, - RequestQueueCollectionClient, -] diff --git a/src/crawlee/storage_clients/_file_system/__init__.py b/src/crawlee/storage_clients/_file_system/__init__.py new file mode 100644 index 0000000000..2169896d86 --- /dev/null +++ b/src/crawlee/storage_clients/_file_system/__init__.py @@ -0,0 +1,11 @@ +from ._dataset_client import FileSystemDatasetClient +from ._key_value_store_client import FileSystemKeyValueStoreClient +from ._request_queue_client import FileSystemRequestQueueClient +from ._storage_client import FileSystemStorageClient + +__all__ = [ + 'FileSystemDatasetClient', + 'FileSystemKeyValueStoreClient', + 'FileSystemRequestQueueClient', + 'FileSystemStorageClient', +] diff --git a/src/crawlee/storage_clients/_file_system/_dataset_client.py b/src/crawlee/storage_clients/_file_system/_dataset_client.py new file mode 100644 index 0000000000..5db837612f --- /dev/null +++ b/src/crawlee/storage_clients/_file_system/_dataset_client.py @@ -0,0 +1,428 @@ +from __future__ import annotations + +import asyncio +import json +import shutil +from datetime import datetime, timezone +from logging import getLogger +from pathlib import Path +from typing import TYPE_CHECKING, ClassVar + +from pydantic import ValidationError +from typing_extensions import override + +from crawlee._utils.crypto import crypto_random_object_id +from crawlee.storage_clients._base import DatasetClient +from crawlee.storage_clients.models import DatasetItemsListPage, DatasetMetadata + +from ._utils import METADATA_FILENAME, json_dumps + +if TYPE_CHECKING: + from collections.abc import AsyncIterator + from typing import Any + + from crawlee.configuration import Configuration + +logger = getLogger(__name__) + + +class FileSystemDatasetClient(DatasetClient): + """File system implementation of the dataset client. + + This client persists dataset items to the file system as individual JSON files within a structured + directory hierarchy following the pattern: + + ``` + {STORAGE_DIR}/datasets/{DATASET_ID}/{ITEM_ID}.json + ``` + + Each item is stored as a separate file, which allows for durability and the ability to + recover after process termination. Dataset operations like filtering, sorting, and pagination are + implemented by processing the stored files according to the requested parameters. 
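The per-item layout described above can be pictured with a tiny standalone snippet. The `./storage` directory and the `default` dataset name are the usual Crawlee defaults and are assumed here; the `datasets` subdirectory and the 9-digit zero padding come from the constants defined just below.

```
from pathlib import Path

storage_dir = Path('./storage')
dataset_dir = storage_dir / 'datasets' / 'default'

# Items are written as zero-padded JSON files, so lexical order matches insertion order.
for item_id in (1, 2, 10):
    print(dataset_dir / f'{str(item_id).zfill(9)}.json')
# storage/datasets/default/000000001.json
# storage/datasets/default/000000002.json
# storage/datasets/default/000000010.json
```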
+ + This implementation is ideal for long-running crawlers where data persistence is important, + and for development environments where you want to easily inspect the collected data between runs. + """ + + _STORAGE_SUBDIR = 'datasets' + """The name of the subdirectory where datasets are stored.""" + + _ITEM_FILENAME_DIGITS = 9 + """Number of digits used for the dataset item file names (e.g., 000000019.json).""" + + _cache_by_name: ClassVar[dict[str, FileSystemDatasetClient]] = {} + """A dictionary to cache clients by their names.""" + + def __init__( + self, + *, + id: str, + name: str, + created_at: datetime, + accessed_at: datetime, + modified_at: datetime, + item_count: int, + storage_dir: Path, + ) -> None: + """Initialize a new instance. + + Preferably use the `FileSystemDatasetClient.open` class method to create a new instance. + """ + self._metadata = DatasetMetadata( + id=id, + name=name, + created_at=created_at, + accessed_at=accessed_at, + modified_at=modified_at, + item_count=item_count, + ) + + self._storage_dir = storage_dir + + # Internal attributes + self._lock = asyncio.Lock() + """A lock to ensure that only one operation is performed at a time.""" + + @override + @property + def metadata(self) -> DatasetMetadata: + return self._metadata + + @property + def path_to_dataset(self) -> Path: + """The full path to the dataset directory.""" + return self._storage_dir / self._STORAGE_SUBDIR / self.metadata.name + + @property + def path_to_metadata(self) -> Path: + """The full path to the dataset metadata file.""" + return self.path_to_dataset / METADATA_FILENAME + + @override + @classmethod + async def open( + cls, + *, + id: str | None, + name: str | None, + configuration: Configuration, + ) -> FileSystemDatasetClient: + if id: + raise ValueError( + 'Opening a dataset by "id" is not supported for file system storage client, use "name" instead.' + ) + + name = name or configuration.default_dataset_id + + # Check if the client is already cached by name. + if name in cls._cache_by_name: + client = cls._cache_by_name[name] + await client._update_metadata(update_accessed_at=True) # noqa: SLF001 + return client + + storage_dir = Path(configuration.storage_dir) + dataset_path = storage_dir / cls._STORAGE_SUBDIR / name + metadata_path = dataset_path / METADATA_FILENAME + + # If the dataset directory exists, reconstruct the client from the metadata file. + if dataset_path.exists(): + # If metadata file is missing, raise an error. + if not metadata_path.exists(): + raise ValueError(f'Metadata file not found for dataset "{name}"') + + file = await asyncio.to_thread(open, metadata_path) + try: + file_content = json.load(file) + finally: + await asyncio.to_thread(file.close) + try: + metadata = DatasetMetadata(**file_content) + except ValidationError as exc: + raise ValueError(f'Invalid metadata file for dataset "{name}"') from exc + + client = cls( + id=metadata.id, + name=name, + created_at=metadata.created_at, + accessed_at=metadata.accessed_at, + modified_at=metadata.modified_at, + item_count=metadata.item_count, + storage_dir=storage_dir, + ) + + await client._update_metadata(update_accessed_at=True) + + # Otherwise, create a new dataset client. + else: + now = datetime.now(timezone.utc) + client = cls( + id=crypto_random_object_id(), + name=name, + created_at=now, + accessed_at=now, + modified_at=now, + item_count=0, + storage_dir=storage_dir, + ) + await client._update_metadata() + + # Cache the client by name. 
+ cls._cache_by_name[name] = client + + return client + + @override + async def drop(self) -> None: + # If the client directory exists, remove it recursively. + if self.path_to_dataset.exists(): + async with self._lock: + await asyncio.to_thread(shutil.rmtree, self.path_to_dataset) + + # Remove the client from the cache. + if self.metadata.name in self.__class__._cache_by_name: # noqa: SLF001 + del self.__class__._cache_by_name[self.metadata.name] # noqa: SLF001 + + @override + async def push_data(self, data: list[Any] | dict[str, Any]) -> None: + new_item_count = self.metadata.item_count + + # If data is a list, push each item individually. + if isinstance(data, list): + for item in data: + new_item_count += 1 + await self._push_item(item, new_item_count) + else: + new_item_count += 1 + await self._push_item(data, new_item_count) + + await self._update_metadata( + update_accessed_at=True, + update_modified_at=True, + new_item_count=new_item_count, + ) + + @override + async def get_data( + self, + *, + offset: int = 0, + limit: int | None = 999_999_999_999, + clean: bool = False, + desc: bool = False, + fields: list[str] | None = None, + omit: list[str] | None = None, + unwind: str | None = None, + skip_empty: bool = False, + skip_hidden: bool = False, + flatten: list[str] | None = None, + view: str | None = None, + ) -> DatasetItemsListPage: + # Check for unsupported arguments and log a warning if found. + unsupported_args = { + 'clean': clean, + 'fields': fields, + 'omit': omit, + 'unwind': unwind, + 'skip_hidden': skip_hidden, + 'flatten': flatten, + 'view': view, + } + unsupported = {k: v for k, v in unsupported_args.items() if v not in (False, None)} + + if unsupported: + logger.warning( + f'The arguments {list(unsupported.keys())} of get_data are not supported by the ' + f'{self.__class__.__name__} client.' + ) + + # If the dataset directory does not exist, log a warning and return an empty page. + if not self.path_to_dataset.exists(): + logger.warning(f'Dataset directory not found: {self.path_to_dataset}') + return DatasetItemsListPage( + count=0, + offset=offset, + limit=limit or 0, + total=0, + desc=desc, + items=[], + ) + + # Get the list of sorted data files. + data_files = await self._get_sorted_data_files() + total = len(data_files) + + # Reverse the order if descending order is requested. + if desc: + data_files.reverse() + + # Apply offset and limit slicing. + selected_files = data_files[offset:] + if limit is not None: + selected_files = selected_files[:limit] + + # Read and parse each data file. + items = [] + for file_path in selected_files: + try: + file_content = await asyncio.to_thread(file_path.read_text, encoding='utf-8') + item = json.loads(file_content) + except Exception: + logger.exception(f'Error reading {file_path}, skipping the item.') + continue + + # Skip empty items if requested. + if skip_empty and not item: + continue + + items.append(item) + + await self._update_metadata(update_accessed_at=True) + + # Return a paginated list page of dataset items. 
+ return DatasetItemsListPage( + count=len(items), + offset=offset, + limit=limit or total - offset, + total=total, + desc=desc, + items=items, + ) + + @override + async def iterate_items( + self, + *, + offset: int = 0, + limit: int | None = None, + clean: bool = False, + desc: bool = False, + fields: list[str] | None = None, + omit: list[str] | None = None, + unwind: str | None = None, + skip_empty: bool = False, + skip_hidden: bool = False, + ) -> AsyncIterator[dict]: + # Check for unsupported arguments and log a warning if found. + unsupported_args = { + 'clean': clean, + 'fields': fields, + 'omit': omit, + 'unwind': unwind, + 'skip_hidden': skip_hidden, + } + unsupported = {k: v for k, v in unsupported_args.items() if v not in (False, None)} + + if unsupported: + logger.warning( + f'The arguments {list(unsupported.keys())} of iterate are not supported ' + f'by the {self.__class__.__name__} client.' + ) + + # If the dataset directory does not exist, log a warning and return immediately. + if not self.path_to_dataset.exists(): + logger.warning(f'Dataset directory not found: {self.path_to_dataset}') + return + + # Get the list of sorted data files. + data_files = await self._get_sorted_data_files() + + # Reverse the order if descending order is requested. + if desc: + data_files.reverse() + + # Apply offset and limit slicing. + selected_files = data_files[offset:] + if limit is not None: + selected_files = selected_files[:limit] + + # Iterate over each data file, reading and yielding its parsed content. + for file_path in selected_files: + try: + file_content = await asyncio.to_thread(file_path.read_text, encoding='utf-8') + item = json.loads(file_content) + except Exception: + logger.exception(f'Error reading {file_path}, skipping the item.') + continue + + # Skip empty items if requested. + if skip_empty and not item: + continue + + yield item + + await self._update_metadata(update_accessed_at=True) + + async def _update_metadata( + self, + *, + new_item_count: int | None = None, + update_accessed_at: bool = False, + update_modified_at: bool = False, + ) -> None: + """Update the dataset metadata file with current information. + + Args: + new_item_count: If provided, update the item count to this value. + update_accessed_at: If True, update the `accessed_at` timestamp to the current time. + update_modified_at: If True, update the `modified_at` timestamp to the current time. + """ + now = datetime.now(timezone.utc) + + if update_accessed_at: + self._metadata.accessed_at = now + if update_modified_at: + self._metadata.modified_at = now + if new_item_count is not None: + self._metadata.item_count = new_item_count + + # Ensure the parent directory for the metadata file exists. + await asyncio.to_thread(self.path_to_metadata.parent.mkdir, parents=True, exist_ok=True) + + # Dump the serialized metadata to the file. + data = await json_dumps(self._metadata.model_dump()) + await asyncio.to_thread(self.path_to_metadata.write_text, data, encoding='utf-8') + + async def _push_item(self, item: dict[str, Any], item_id: int) -> None: + """Push a single item to the dataset. + + This method writes the item as a JSON file with a zero-padded numeric filename + that reflects its position in the dataset sequence. + + Args: + item: The data item to add to the dataset. + item_id: The sequential ID to use for this item's filename. + """ + # Acquire the lock to perform file operations safely. + async with self._lock: + # Generate the filename for the new item using zero-padded numbering. 
+ filename = f'{str(item_id).zfill(self._ITEM_FILENAME_DIGITS)}.json' + file_path = self.path_to_dataset / filename + + # Ensure the dataset directory exists. + await asyncio.to_thread(self.path_to_dataset.mkdir, parents=True, exist_ok=True) + + # Dump the serialized item to the file. + data = await json_dumps(item) + await asyncio.to_thread(file_path.write_text, data, encoding='utf-8') + + async def _get_sorted_data_files(self) -> list[Path]: + """Retrieve and return a sorted list of data files in the dataset directory. + + The files are sorted numerically based on the filename (without extension), + which corresponds to the order items were added to the dataset. + + Returns: + A list of `Path` objects pointing to data files, sorted by numeric filename. + """ + # Retrieve and sort all JSON files in the dataset directory numerically. + files = await asyncio.to_thread( + sorted, + self.path_to_dataset.glob('*.json'), + key=lambda f: int(f.stem) if f.stem.isdigit() else 0, + ) + + # Remove the metadata file from the list if present. + if self.path_to_metadata in files: + files.remove(self.path_to_metadata) + + return files diff --git a/src/crawlee/storage_clients/_file_system/_key_value_store_client.py b/src/crawlee/storage_clients/_file_system/_key_value_store_client.py new file mode 100644 index 0000000000..f7db025a25 --- /dev/null +++ b/src/crawlee/storage_clients/_file_system/_key_value_store_client.py @@ -0,0 +1,386 @@ +from __future__ import annotations + +import asyncio +import json +import shutil +from datetime import datetime, timezone +from logging import getLogger +from pathlib import Path +from typing import TYPE_CHECKING, Any, ClassVar + +from pydantic import ValidationError +from typing_extensions import override + +from crawlee._utils.crypto import crypto_random_object_id +from crawlee._utils.file import infer_mime_type +from crawlee.storage_clients._base import KeyValueStoreClient +from crawlee.storage_clients.models import KeyValueStoreMetadata, KeyValueStoreRecord, KeyValueStoreRecordMetadata + +from ._utils import METADATA_FILENAME, json_dumps + +if TYPE_CHECKING: + from collections.abc import AsyncIterator + + from crawlee.configuration import Configuration + + +logger = getLogger(__name__) + + +class FileSystemKeyValueStoreClient(KeyValueStoreClient): + """File system implementation of the key-value store client. + + This client persists data to the file system, making it suitable for scenarios where data needs to + survive process restarts. Keys are mapped to file paths in a directory structure following the pattern: + + ``` + {STORAGE_DIR}/key_value_stores/{STORE_ID}/{KEY} + ``` + + Binary data is stored as-is, while JSON and text data are stored in human-readable format. + The implementation automatically handles serialization based on the content type and + maintains metadata about each record. + + This implementation is ideal for long-running crawlers where persistence is important and + for development environments where you want to easily inspect the stored data between runs. + """ + + _STORAGE_SUBDIR = 'key_value_stores' + """The name of the subdirectory where key-value stores are stored.""" + + _cache_by_name: ClassVar[dict[str, FileSystemKeyValueStoreClient]] = {} + """A dictionary to cache clients by their names.""" + + def __init__( + self, + *, + id: str, + name: str, + created_at: datetime, + accessed_at: datetime, + modified_at: datetime, + storage_dir: Path, + ) -> None: + """Initialize a new instance. 
+
+        Preferably use the `FileSystemKeyValueStoreClient.open` class method to create a new instance.
+        """
+        self._metadata = KeyValueStoreMetadata(
+            id=id,
+            name=name,
+            created_at=created_at,
+            accessed_at=accessed_at,
+            modified_at=modified_at,
+        )
+
+        self._storage_dir = storage_dir
+
+        # Internal attributes
+        self._lock = asyncio.Lock()
+        """A lock to ensure that only one operation is performed at a time."""
+
+    @override
+    @property
+    def metadata(self) -> KeyValueStoreMetadata:
+        return self._metadata
+
+    @property
+    def path_to_kvs(self) -> Path:
+        """The full path to the key-value store directory."""
+        return self._storage_dir / self._STORAGE_SUBDIR / self.metadata.name
+
+    @property
+    def path_to_metadata(self) -> Path:
+        """The full path to the key-value store metadata file."""
+        return self.path_to_kvs / METADATA_FILENAME
+
+    @override
+    @classmethod
+    async def open(
+        cls,
+        *,
+        id: str | None,
+        name: str | None,
+        configuration: Configuration,
+    ) -> FileSystemKeyValueStoreClient:
+        if id:
+            raise ValueError(
+                'Opening a key-value store by "id" is not supported for file system storage client, use "name" instead.'
+            )
+
+        name = name or configuration.default_key_value_store_id
+
+        # Check if the client is already cached by name.
+        if name in cls._cache_by_name:
+            client = cls._cache_by_name[name]
+            await client._update_metadata(update_accessed_at=True)  # noqa: SLF001
+            return client
+
+        storage_dir = Path(configuration.storage_dir)
+        kvs_path = storage_dir / cls._STORAGE_SUBDIR / name
+        metadata_path = kvs_path / METADATA_FILENAME
+
+        # If the key-value store directory exists, reconstruct the client from the metadata file.
+        if kvs_path.exists():
+            # If metadata file is missing, raise an error.
+            if not metadata_path.exists():
+                raise ValueError(f'Metadata file not found for key-value store "{name}"')
+
+            file = await asyncio.to_thread(open, metadata_path)
+            try:
+                file_content = json.load(file)
+            finally:
+                await asyncio.to_thread(file.close)
+            try:
+                metadata = KeyValueStoreMetadata(**file_content)
+            except ValidationError as exc:
+                raise ValueError(f'Invalid metadata file for key-value store "{name}"') from exc
+
+            client = cls(
+                id=metadata.id,
+                name=name,
+                created_at=metadata.created_at,
+                accessed_at=metadata.accessed_at,
+                modified_at=metadata.modified_at,
+                storage_dir=storage_dir,
+            )
+
+            await client._update_metadata(update_accessed_at=True)
+
+        # Otherwise, create a new key-value store client.
+        else:
+            now = datetime.now(timezone.utc)
+            client = cls(
+                id=crypto_random_object_id(),
+                name=name,
+                created_at=now,
+                accessed_at=now,
+                modified_at=now,
+                storage_dir=storage_dir,
+            )
+            await client._update_metadata()
+
+        # Cache the client by name.
+        cls._cache_by_name[name] = client
+
+        return client
+
+    @override
+    async def drop(self) -> None:
+        # If the client directory exists, remove it recursively.
+        if self.path_to_kvs.exists():
+            async with self._lock:
+                await asyncio.to_thread(shutil.rmtree, self.path_to_kvs)
+
+        # Remove the client from the cache.
+ if self.metadata.name in self.__class__._cache_by_name: # noqa: SLF001 + del self.__class__._cache_by_name[self.metadata.name] # noqa: SLF001 + + @override + async def get_value(self, *, key: str) -> KeyValueStoreRecord | None: + # Update the metadata to record access + await self._update_metadata(update_accessed_at=True) + + record_path = self.path_to_kvs / key + + if not record_path.exists(): + return None + + # Found a file for this key, now look for its metadata + record_metadata_filepath = record_path.with_name(f'{record_path.name}.{METADATA_FILENAME}') + if not record_metadata_filepath.exists(): + logger.warning(f'Found value file for key "{key}" but no metadata file.') + return None + + # Read the metadata file + async with self._lock: + file = await asyncio.to_thread(open, record_metadata_filepath) + try: + metadata_content = json.load(file) + except json.JSONDecodeError: + logger.warning(f'Invalid metadata file for key "{key}"') + return None + finally: + await asyncio.to_thread(file.close) + + try: + metadata = KeyValueStoreRecordMetadata(**metadata_content) + except ValidationError: + logger.warning(f'Invalid metadata schema for key "{key}"') + return None + + # Read the actual value + value_bytes = await asyncio.to_thread(record_path.read_bytes) + + # Handle JSON values + if 'application/json' in metadata.content_type: + try: + value = json.loads(value_bytes.decode('utf-8')) + except (json.JSONDecodeError, UnicodeDecodeError): + logger.warning(f'Failed to decode JSON value for key "{key}"') + return None + + # Handle text values + elif metadata.content_type.startswith('text/'): + try: + value = value_bytes.decode('utf-8') + except UnicodeDecodeError: + logger.warning(f'Failed to decode text value for key "{key}"') + return None + + # Handle binary values + else: + value = value_bytes + + # Calculate the size of the value in bytes + size = len(value_bytes) + + return KeyValueStoreRecord( + key=metadata.key, + value=value, + content_type=metadata.content_type, + size=size, + ) + + @override + async def set_value(self, *, key: str, value: Any, content_type: str | None = None) -> None: + content_type = content_type or infer_mime_type(value) + + # Serialize the value to bytes. + if 'application/json' in content_type: + value_bytes = (await json_dumps(value)).encode('utf-8') + elif isinstance(value, str): + value_bytes = value.encode('utf-8') + elif isinstance(value, (bytes, bytearray)): + value_bytes = value + else: + # Fallback: attempt to convert to string and encode. + value_bytes = str(value).encode('utf-8') + + record_path = self.path_to_kvs / key + + # Prepare the metadata + size = len(value_bytes) + record_metadata = KeyValueStoreRecordMetadata(key=key, content_type=content_type, size=size) + record_metadata_filepath = record_path.with_name(f'{record_path.name}.{METADATA_FILENAME}') + record_metadata_content = await json_dumps(record_metadata.model_dump()) + + async with self._lock: + # Ensure the key-value store directory exists. + await asyncio.to_thread(self.path_to_kvs.mkdir, parents=True, exist_ok=True) + + # Write the value to the file. + await asyncio.to_thread(record_path.write_bytes, value_bytes) + + # Write the record metadata to the file. + await asyncio.to_thread( + record_metadata_filepath.write_text, + record_metadata_content, + encoding='utf-8', + ) + + # Update the KVS metadata to record the access and modification. 
+ await self._update_metadata(update_accessed_at=True, update_modified_at=True) + + @override + async def delete_value(self, *, key: str) -> None: + record_path = self.path_to_kvs / key + metadata_path = record_path.with_name(f'{record_path.name}.{METADATA_FILENAME}') + deleted = False + + async with self._lock: + # Delete the value file and its metadata if found + if record_path.exists(): + await asyncio.to_thread(record_path.unlink) + + # Delete the metadata file if it exists + if metadata_path.exists(): + await asyncio.to_thread(metadata_path.unlink) + else: + logger.warning(f'Found value file for key "{key}" but no metadata file when trying to delete it.') + + deleted = True + + # If we deleted something, update the KVS metadata + if deleted: + await self._update_metadata(update_accessed_at=True, update_modified_at=True) + + @override + async def iterate_keys( + self, + *, + exclusive_start_key: str | None = None, + limit: int | None = None, + ) -> AsyncIterator[KeyValueStoreRecordMetadata]: + # Check if the KVS directory exists + if not self.path_to_kvs.exists(): + return + + count = 0 + async with self._lock: + # Get all files in the KVS directory, sorted alphabetically + files = sorted(await asyncio.to_thread(list, self.path_to_kvs.glob('*'))) + + for file_path in files: + # Skip the main metadata file + if file_path.name == METADATA_FILENAME: + continue + + # Only process metadata files for records + if not file_path.name.endswith(f'.{METADATA_FILENAME}'): + continue + + # Extract the base key name from the metadata filename + key_name = file_path.name[: -len(f'.{METADATA_FILENAME}')] + + # Apply exclusive_start_key filter if provided + if exclusive_start_key is not None and key_name <= exclusive_start_key: + continue + + # Try to read and parse the metadata file + try: + metadata_content = await asyncio.to_thread(file_path.read_text, encoding='utf-8') + metadata_dict = json.loads(metadata_content) + record_metadata = KeyValueStoreRecordMetadata(**metadata_dict) + + yield record_metadata + + count += 1 + if limit and count >= limit: + break + + except (json.JSONDecodeError, ValidationError) as e: + logger.warning(f'Failed to parse metadata file {file_path}: {e}') + + # Update accessed_at timestamp + await self._update_metadata(update_accessed_at=True) + + @override + async def get_public_url(self, *, key: str) -> str: + raise NotImplementedError('Public URLs are not supported for file system key-value stores.') + + async def _update_metadata( + self, + *, + update_accessed_at: bool = False, + update_modified_at: bool = False, + ) -> None: + """Update the KVS metadata file with current information. + + Args: + update_accessed_at: If True, update the `accessed_at` timestamp to the current time. + update_modified_at: If True, update the `modified_at` timestamp to the current time. + """ + now = datetime.now(timezone.utc) + + if update_accessed_at: + self._metadata.accessed_at = now + if update_modified_at: + self._metadata.modified_at = now + + # Ensure the parent directory for the metadata file exists. + await asyncio.to_thread(self.path_to_metadata.parent.mkdir, parents=True, exist_ok=True) + + # Dump the serialized metadata to the file. 
+ data = await json_dumps(self._metadata.model_dump()) + await asyncio.to_thread(self.path_to_metadata.write_text, data, encoding='utf-8') diff --git a/src/crawlee/storage_clients/_file_system/_request_queue_client.py b/src/crawlee/storage_clients/_file_system/_request_queue_client.py new file mode 100644 index 0000000000..a88168e894 --- /dev/null +++ b/src/crawlee/storage_clients/_file_system/_request_queue_client.py @@ -0,0 +1,672 @@ +from __future__ import annotations + +import asyncio +import json +import shutil +from datetime import datetime, timezone +from logging import getLogger +from pathlib import Path +from typing import TYPE_CHECKING, ClassVar + +from pydantic import ValidationError +from typing_extensions import override + +from crawlee import Request +from crawlee._utils.crypto import crypto_random_object_id +from crawlee.storage_clients._base import RequestQueueClient +from crawlee.storage_clients.models import AddRequestsResponse, ProcessedRequest, RequestQueueMetadata + +from ._utils import METADATA_FILENAME, json_dumps + +if TYPE_CHECKING: + from collections.abc import Sequence + + from crawlee.configuration import Configuration + +logger = getLogger(__name__) + + +class FileSystemRequestQueueClient(RequestQueueClient): + """A file system implementation of the request queue client. + + This client persists requests to the file system as individual JSON files, making it suitable for scenarios + where data needs to survive process restarts. Each request is stored as a separate file in a directory + structure following the pattern: + + ``` + {STORAGE_DIR}/request_queues/{QUEUE_ID}/{REQUEST_ID}.json + ``` + + The implementation uses file timestamps for FIFO ordering of regular requests and maintains in-memory sets + for tracking in-progress and forefront requests. File system storage provides durability at the cost of + slower I/O operations compared to memory-based storage. + + This implementation is ideal for long-running crawlers where persistence is important and for situations + where you need to resume crawling after process termination. + """ + + _STORAGE_SUBDIR = 'request_queues' + """The name of the subdirectory where request queues are stored.""" + + _cache_by_name: ClassVar[dict[str, FileSystemRequestQueueClient]] = {} + """A dictionary to cache clients by their names.""" + + def __init__( + self, + *, + id: str, + name: str, + created_at: datetime, + accessed_at: datetime, + modified_at: datetime, + had_multiple_clients: bool, + handled_request_count: int, + pending_request_count: int, + stats: dict, + total_request_count: int, + storage_dir: Path, + ) -> None: + """Initialize a new instance. + + Preferably use the `FileSystemRequestQueueClient.open` class method to create a new instance. 
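+
+        A minimal usage sketch (illustrative only; it assumes the global `Configuration` and the
+        `Request.from_url` helper for constructing requests):
+
+            configuration = Configuration.get_global_configuration()
+            client = await FileSystemRequestQueueClient.open(
+                id=None,
+                name='my-queue',
+                configuration=configuration,
+            )
+            await client.add_batch_of_requests([Request.from_url('https://crawlee.dev')])
+            request = await client.fetch_next_request()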
+ """ + self._metadata = RequestQueueMetadata( + id=id, + name=name, + created_at=created_at, + accessed_at=accessed_at, + modified_at=modified_at, + had_multiple_clients=had_multiple_clients, + handled_request_count=handled_request_count, + pending_request_count=pending_request_count, + stats=stats, + total_request_count=total_request_count, + ) + + self._storage_dir = storage_dir + + # Internal attributes + self._lock = asyncio.Lock() + """A lock to ensure that only one operation is performed at a time.""" + + self._in_progress = set[str]() + """A set of request IDs that are currently being processed.""" + + self._forefront_requests = set[str]() + """A set of request IDs that should be prioritized (added with forefront=True).""" + + @override + @property + def metadata(self) -> RequestQueueMetadata: + return self._metadata + + @property + def path_to_rq(self) -> Path: + """The full path to the request queue directory.""" + return self._storage_dir / self._STORAGE_SUBDIR / self.metadata.name + + @property + def path_to_metadata(self) -> Path: + """The full path to the request queue metadata file.""" + return self.path_to_rq / METADATA_FILENAME + + @override + @classmethod + async def open( + cls, + *, + id: str | None, + name: str | None, + configuration: Configuration, + ) -> FileSystemRequestQueueClient: + if id: + raise ValueError( + 'Opening a request queue by "id" is not supported for file system storage client, use "name" instead.' + ) + + name = name or configuration.default_request_queue_id + + # Check if the client is already cached by name. + if name in cls._cache_by_name: + client = cls._cache_by_name[name] + await client._update_metadata(update_accessed_at=True) # noqa: SLF001 + return client + + storage_dir = Path(configuration.storage_dir) + rq_path = storage_dir / cls._STORAGE_SUBDIR / name + metadata_path = rq_path / METADATA_FILENAME + + # If the RQ directory exists, reconstruct the client from the metadata file. + if rq_path.exists() and not configuration.purge_on_start: + # If metadata file is missing, raise an error. 
+            if not metadata_path.exists():
+                raise ValueError(f'Metadata file not found for request queue "{name}"')
+
+            file = await asyncio.to_thread(open, metadata_path)
+            try:
+                file_content = json.load(file)
+            finally:
+                await asyncio.to_thread(file.close)
+            try:
+                metadata = RequestQueueMetadata(**file_content)
+            except ValidationError as exc:
+                raise ValueError(f'Invalid metadata file for request queue "{name}"') from exc
+
+            client = cls(
+                id=metadata.id,
+                name=name,
+                created_at=metadata.created_at,
+                accessed_at=metadata.accessed_at,
+                modified_at=metadata.modified_at,
+                had_multiple_clients=metadata.had_multiple_clients,
+                handled_request_count=metadata.handled_request_count,
+                pending_request_count=metadata.pending_request_count,
+                stats=metadata.stats,
+                total_request_count=metadata.total_request_count,
+                storage_dir=storage_dir,
+            )
+
+            # Recalculate request counts from actual files to ensure consistency
+            handled_count = 0
+            pending_count = 0
+            request_files = await asyncio.to_thread(list, rq_path.glob('*.json'))
+            for request_file in request_files:
+                if request_file.name == METADATA_FILENAME:
+                    continue
+
+                try:
+                    file = await asyncio.to_thread(open, request_file)
+                    try:
+                        data = json.load(file)
+                        if data.get('handled_at') is not None:
+                            handled_count += 1
+                        else:
+                            pending_count += 1
+                    finally:
+                        await asyncio.to_thread(file.close)
+                except (json.JSONDecodeError, ValidationError):
+                    logger.warning(f'Failed to parse request file: {request_file}')
+
+            await client._update_metadata(
+                update_accessed_at=True,
+                new_handled_request_count=handled_count,
+                new_pending_request_count=pending_count,
+                new_total_request_count=handled_count + pending_count,
+            )
+
+        # Otherwise, create a new request queue client.
+        else:
+            # If purge_on_start is true and the directory exists, remove it
+            if configuration.purge_on_start and rq_path.exists():
+                await asyncio.to_thread(shutil.rmtree, rq_path)
+
+            now = datetime.now(timezone.utc)
+            client = cls(
+                id=crypto_random_object_id(),
+                name=name,
+                created_at=now,
+                accessed_at=now,
+                modified_at=now,
+                had_multiple_clients=False,
+                handled_request_count=0,
+                pending_request_count=0,
+                stats={},
+                total_request_count=0,
+                storage_dir=storage_dir,
+            )
+            await client._update_metadata()
+
+        # Cache the client by name.
+        cls._cache_by_name[name] = client
+
+        return client
+
+    @override
+    async def drop(self) -> None:
+        # If the client directory exists, remove it recursively.
+        if self.path_to_rq.exists():
+            async with self._lock:
+                await asyncio.to_thread(shutil.rmtree, self.path_to_rq)
+
+        # Remove the client from the cache.
+        if self.metadata.name in self.__class__._cache_by_name:  # noqa: SLF001
+            del self.__class__._cache_by_name[self.metadata.name]  # noqa: SLF001
+
+    @override
+    async def add_batch_of_requests(
+        self,
+        requests: Sequence[Request],
+        *,
+        forefront: bool = False,
+    ) -> AddRequestsResponse:
+        """Add a batch of requests to the queue.
+
+        Args:
+            requests: The requests to add.
+            forefront: Whether to add the requests to the beginning of the queue.
+
+        Returns:
+            Response containing information about the added requests.
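+
+        Note:
+            Requests whose `unique_key` already exists in the queue are not stored again; they are
+            reported back with `was_already_present=True` (and `was_already_handled=True` when the
+            stored request has already been handled).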
+ """ + async with self._lock: + new_total_request_count = self._metadata.total_request_count + new_pending_request_count = self._metadata.pending_request_count + + processed_requests = [] + + # Create the requests directory if it doesn't exist + await asyncio.to_thread(self.path_to_rq.mkdir, parents=True, exist_ok=True) + + for request in requests: + # Ensure the request has an ID + if not request.id: + request.id = crypto_random_object_id() + + # Check if the request is already in the queue by unique_key + existing_request = None + + # List all request files and check for matching unique_key + request_files = await asyncio.to_thread(list, self.path_to_rq.glob('*.json')) + for request_file in request_files: + # Skip metadata file + if request_file.name == METADATA_FILENAME: + continue + + file = await asyncio.to_thread(open, request_file) + try: + file_content = json.load(file) + if file_content.get('unique_key') == request.unique_key: + existing_request = Request(**file_content) + break + except (json.JSONDecodeError, ValidationError): + logger.warning(f'Failed to parse request file: {request_file}') + finally: + await asyncio.to_thread(file.close) + + was_already_present = existing_request is not None + was_already_handled = ( + was_already_present and existing_request and existing_request.handled_at is not None + ) + + # If the request is already in the queue and handled, don't add it again + if was_already_handled and existing_request: + processed_requests.append( + ProcessedRequest( + id=existing_request.id, + unique_key=request.unique_key, + was_already_present=True, + was_already_handled=True, + ) + ) + continue + + # If forefront and existing request is not handled, mark it as forefront + if forefront and was_already_present and not was_already_handled and existing_request: + self._forefront_requests.add(existing_request.id) + processed_requests.append( + ProcessedRequest( + id=existing_request.id, + unique_key=request.unique_key, + was_already_present=True, + was_already_handled=False, + ) + ) + continue + + # If the request is already in the queue but not handled, update it + if was_already_present and existing_request: + # Update the existing request file + request_path = self.path_to_rq / f'{existing_request.id}.json' + request_data = await json_dumps(existing_request.model_dump()) + await asyncio.to_thread(request_path.write_text, request_data, encoding='utf-8') + + processed_requests.append( + ProcessedRequest( + id=existing_request.id, + unique_key=request.unique_key, + was_already_present=True, + was_already_handled=False, + ) + ) + continue + + # Add the new request to the queue + request_path = self.path_to_rq / f'{request.id}.json' + + # Create a data dictionary from the request and remove handled_at if it's None + request_dict = request.model_dump() + if request_dict.get('handled_at') is None: + request_dict.pop('handled_at', None) + + request_data = await json_dumps(request_dict) + await asyncio.to_thread(request_path.write_text, request_data, encoding='utf-8') + + # Update metadata counts + new_total_request_count += 1 + new_pending_request_count += 1 + + # If forefront, add to the forefront set + if forefront: + self._forefront_requests.add(request.id) + + processed_requests.append( + ProcessedRequest( + id=request.id, + unique_key=request.unique_key, + was_already_present=False, + was_already_handled=False, + ) + ) + + await self._update_metadata( + update_modified_at=True, + update_accessed_at=True, + new_total_request_count=new_total_request_count, + 
new_pending_request_count=new_pending_request_count, + ) + + return AddRequestsResponse( + processed_requests=processed_requests, + unprocessed_requests=[], + ) + + @override + async def get_request(self, request_id: str) -> Request | None: + """Retrieve a request from the queue. + + Args: + request_id: ID of the request to retrieve. + + Returns: + The retrieved request, or None, if it did not exist. + """ + request_path = self.path_to_rq / f'{request_id}.json' + + if await asyncio.to_thread(request_path.exists): + file = await asyncio.to_thread(open, request_path) + try: + file_content = json.load(file) + return Request(**file_content) + except (json.JSONDecodeError, ValidationError) as exc: + logger.warning(f'Failed to parse request file {request_path}: {exc!s}') + finally: + await asyncio.to_thread(file.close) + + return None + + @override + async def fetch_next_request(self) -> Request | None: + """Return the next request in the queue to be processed. + + Once you successfully finish processing of the request, you need to call `RequestQueue.mark_request_as_handled` + to mark the request as handled in the queue. If there was some error in processing the request, call + `RequestQueue.reclaim_request` instead, so that the queue will give the request to some other consumer + in another call to the `fetch_next_request` method. + + Returns: + The request or `None` if there are no more pending requests. + """ + async with self._lock: + # Create the requests directory if it doesn't exist + await asyncio.to_thread(self.path_to_rq.mkdir, parents=True, exist_ok=True) + + # List all request files + request_files = await asyncio.to_thread(list, self.path_to_rq.glob('*.json')) + + # First check for forefront requests + forefront_requests = [] + regular_requests = [] + + # Get file creation times for sorting regular requests in FIFO order + request_file_times = {} + + # Separate requests into forefront and regular + for request_file in request_files: + # Skip metadata file + if request_file.name == METADATA_FILENAME: + continue + + # Extract request ID from filename + request_id = request_file.stem + + # Skip if already in progress + if request_id in self._in_progress: + continue + + # Get file creation/modification time for FIFO ordering + try: + file_stat = await asyncio.to_thread(request_file.stat) + request_file_times[request_file] = file_stat.st_mtime + except Exception: + # If we can't get the time, use 0 (oldest) + request_file_times[request_file] = 0 + + if request_id in self._forefront_requests: + forefront_requests.append(request_file) + else: + regular_requests.append(request_file) + + # Sort regular requests by creation time (FIFO order) + regular_requests.sort(key=lambda f: request_file_times[f]) + + # Prioritize forefront requests + prioritized_files = forefront_requests + regular_requests + + # Process files in prioritized order + for request_file in prioritized_files: + request_id = request_file.stem + + file = await asyncio.to_thread(open, request_file) + try: + file_content = json.load(file) + # Skip if already handled + if file_content.get('handled_at') is not None: + continue + + # Create request object + request = Request(**file_content) + + # Mark as in-progress in memory + self._in_progress.add(request.id) + + # Remove from forefront set if it was there + self._forefront_requests.discard(request.id) + + # Update accessed timestamp + await self._update_metadata(update_accessed_at=True) + + except (json.JSONDecodeError, ValidationError) as exc: + logger.warning(f'Failed to parse 
request file {request_file}: {exc!s}') + else: + return request + finally: + await asyncio.to_thread(file.close) + + return None + + @override + async def mark_request_as_handled(self, request: Request) -> ProcessedRequest | None: + """Mark a request as handled after successful processing. + + Handled requests will never again be returned by the `fetch_next_request` method. + + Args: + request: The request to mark as handled. + + Returns: + Information about the queue operation. `None` if the given request was not in progress. + """ + async with self._lock: + # Check if the request is in progress + if request.id not in self._in_progress: + return None + + # Remove from in-progress set + self._in_progress.discard(request.id) + + # Update the request object - set handled_at timestamp + if request.handled_at is None: + request.handled_at = datetime.now(timezone.utc) + + # Write the updated request back to the requests directory + request_path = self.path_to_rq / f'{request.id}.json' + + if not await asyncio.to_thread(request_path.exists): + return None + + request_data = await json_dumps(request.model_dump()) + await asyncio.to_thread(request_path.write_text, request_data, encoding='utf-8') + + # Update metadata timestamps + await self._update_metadata( + update_modified_at=True, + update_accessed_at=True, + new_handled_request_count=self._metadata.handled_request_count + 1, + new_pending_request_count=self._metadata.pending_request_count - 1, + ) + + return ProcessedRequest( + id=request.id, + unique_key=request.unique_key, + was_already_present=True, + was_already_handled=True, + ) + + @override + async def reclaim_request( + self, + request: Request, + *, + forefront: bool = False, + ) -> ProcessedRequest | None: + """Reclaim a failed request back to the queue. + + The request will be returned for processing later again by another call to `fetch_next_request`. + + Args: + request: The request to return to the queue. + forefront: Whether to add the request to the head or the end of the queue. + + Returns: + Information about the queue operation. `None` if the given request was not in progress. + """ + async with self._lock: + # Check if the request is in progress + if request.id not in self._in_progress: + return None + + # Remove from in-progress set + self._in_progress.discard(request.id) + + # If forefront is true, mark this request as priority + if forefront: + self._forefront_requests.add(request.id) + else: + # Make sure it's not in the forefront set if it was previously added there + self._forefront_requests.discard(request.id) + + # To simulate changing the file timestamp for FIFO ordering, + # we'll update the file with current timestamp + request_path = self.path_to_rq / f'{request.id}.json' + + if not await asyncio.to_thread(request_path.exists): + return None + + request_data = await json_dumps(request.model_dump()) + await asyncio.to_thread(request_path.write_text, request_data, encoding='utf-8') + + # Update metadata timestamps + await self._update_metadata(update_modified_at=True, update_accessed_at=True) + + return ProcessedRequest( + id=request.id, + unique_key=request.unique_key, + was_already_present=True, + was_already_handled=False, + ) + + @override + async def is_empty(self) -> bool: + """Check if the queue is empty. + + Returns: + True if the queue is empty, False otherwise. 
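+
+        Note:
+            This check scans every request file in the queue directory, so its cost grows with the
+            number of stored requests; it is not a cached counter lookup.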
+        """
+        # Update accessed timestamp when checking if queue is empty
+        await self._update_metadata(update_accessed_at=True)
+
+        # Create the requests directory if it doesn't exist
+        await asyncio.to_thread(self.path_to_rq.mkdir, parents=True, exist_ok=True)
+
+        # List all request files
+        request_files = await asyncio.to_thread(list, self.path_to_rq.glob('*.json'))
+
+        # Check each file to see if there are any unhandled requests
+        for request_file in request_files:
+            # Skip metadata file
+            if request_file.name == METADATA_FILENAME:
+                continue
+
+            file = await asyncio.to_thread(open, request_file)
+            try:
+                file_content = json.load(file)
+                # If any request is not handled, the queue is not empty
+                if file_content.get('handled_at') is None:
+                    return False
+            except (json.JSONDecodeError, ValidationError):
+                logger.warning(f'Failed to parse request file: {request_file}')
+            finally:
+                await asyncio.to_thread(file.close)
+
+        # If we got here, all requests are handled or there are no requests
+        return True
+
+    async def _update_metadata(
+        self,
+        *,
+        new_handled_request_count: int | None = None,
+        new_pending_request_count: int | None = None,
+        new_total_request_count: int | None = None,
+        update_had_multiple_clients: bool = False,
+        update_accessed_at: bool = False,
+        update_modified_at: bool = False,
+    ) -> None:
+        """Update the request queue metadata file with current information.
+
+        Args:
+            new_handled_request_count: If provided, update the handled_request_count to this value.
+            new_pending_request_count: If provided, update the pending_request_count to this value.
+            new_total_request_count: If provided, update the total_request_count to this value.
+            update_had_multiple_clients: If True, set had_multiple_clients to True.
+            update_accessed_at: If True, update the `accessed_at` timestamp to the current time.
+            update_modified_at: If True, update the `modified_at` timestamp to the current time.
+        """
+        # Always create a new timestamp to ensure it's truly updated
+        now = datetime.now(timezone.utc)
+
+        # Update timestamps according to parameters
+        if update_accessed_at:
+            self._metadata.accessed_at = now
+
+        if update_modified_at:
+            self._metadata.modified_at = now
+
+        # Update request counts if provided
+        if new_handled_request_count is not None:
+            self._metadata.handled_request_count = new_handled_request_count
+
+        if new_pending_request_count is not None:
+            self._metadata.pending_request_count = new_pending_request_count
+
+        if new_total_request_count is not None:
+            self._metadata.total_request_count = new_total_request_count
+
+        if update_had_multiple_clients:
+            self._metadata.had_multiple_clients = True
+
+        # Ensure the parent directory for the metadata file exists.
+        await asyncio.to_thread(self.path_to_metadata.parent.mkdir, parents=True, exist_ok=True)
+
+        # Dump the serialized metadata to the file.
+ data = await json_dumps(self._metadata.model_dump()) + await asyncio.to_thread(self.path_to_metadata.write_text, data, encoding='utf-8') diff --git a/src/crawlee/storage_clients/_file_system/_storage_client.py b/src/crawlee/storage_clients/_file_system/_storage_client.py new file mode 100644 index 0000000000..2765d15536 --- /dev/null +++ b/src/crawlee/storage_clients/_file_system/_storage_client.py @@ -0,0 +1,65 @@ +from __future__ import annotations + +from typing_extensions import override + +from crawlee.configuration import Configuration +from crawlee.storage_clients._base import StorageClient + +from ._dataset_client import FileSystemDatasetClient +from ._key_value_store_client import FileSystemKeyValueStoreClient +from ._request_queue_client import FileSystemRequestQueueClient + + +class FileSystemStorageClient(StorageClient): + """File system storage client.""" + + @override + async def open_dataset_client( + self, + *, + id: str | None = None, + name: str | None = None, + configuration: Configuration | None = None, + ) -> FileSystemDatasetClient: + configuration = configuration or Configuration.get_global_configuration() + client = await FileSystemDatasetClient.open(id=id, name=name, configuration=configuration) + + if configuration.purge_on_start: + await client.drop() + client = await FileSystemDatasetClient.open(id=id, name=name, configuration=configuration) + + return client + + @override + async def open_key_value_store_client( + self, + *, + id: str | None = None, + name: str | None = None, + configuration: Configuration | None = None, + ) -> FileSystemKeyValueStoreClient: + configuration = configuration or Configuration.get_global_configuration() + client = await FileSystemKeyValueStoreClient.open(id=id, name=name, configuration=configuration) + + if configuration.purge_on_start: + await client.drop() + client = await FileSystemKeyValueStoreClient.open(id=id, name=name, configuration=configuration) + + return client + + @override + async def open_request_queue_client( + self, + *, + id: str | None = None, + name: str | None = None, + configuration: Configuration | None = None, + ) -> FileSystemRequestQueueClient: + configuration = configuration or Configuration.get_global_configuration() + client = await FileSystemRequestQueueClient.open(id=id, name=name, configuration=configuration) + + if configuration.purge_on_start: + await client.drop() + client = await FileSystemRequestQueueClient.open(id=id, name=name, configuration=configuration) + + return client diff --git a/src/crawlee/storage_clients/_file_system/_utils.py b/src/crawlee/storage_clients/_file_system/_utils.py new file mode 100644 index 0000000000..c172df50cc --- /dev/null +++ b/src/crawlee/storage_clients/_file_system/_utils.py @@ -0,0 +1,23 @@ +from __future__ import annotations + +import asyncio +import json +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from typing import Any + +METADATA_FILENAME = '__metadata__.json' +"""The name of the metadata file for storage clients.""" + + +async def json_dumps(obj: Any) -> str: + """Serialize an object to a JSON-formatted string with specific settings. + + Args: + obj: The object to serialize. + + Returns: + A string containing the JSON representation of the input object. 
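+
+    Usage sketch (illustrative):
+
+        serialized = await json_dumps({'key': 'value'})  # pretty-printed JSON, non-ASCII preserved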
+ """ + return await asyncio.to_thread(json.dumps, obj, ensure_ascii=False, indent=2, default=str) diff --git a/src/crawlee/storage_clients/_file_system/py.typed b/src/crawlee/storage_clients/_file_system/py.typed new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/crawlee/storage_clients/_memory/__init__.py b/src/crawlee/storage_clients/_memory/__init__.py index 09912e124d..3746907b4f 100644 --- a/src/crawlee/storage_clients/_memory/__init__.py +++ b/src/crawlee/storage_clients/_memory/__init__.py @@ -1,17 +1,11 @@ -from ._dataset_client import DatasetClient -from ._dataset_collection_client import DatasetCollectionClient -from ._key_value_store_client import KeyValueStoreClient -from ._key_value_store_collection_client import KeyValueStoreCollectionClient -from ._memory_storage_client import MemoryStorageClient -from ._request_queue_client import RequestQueueClient -from ._request_queue_collection_client import RequestQueueCollectionClient +from ._dataset_client import MemoryDatasetClient +from ._key_value_store_client import MemoryKeyValueStoreClient +from ._request_queue_client import MemoryRequestQueueClient +from ._storage_client import MemoryStorageClient __all__ = [ - 'DatasetClient', - 'DatasetCollectionClient', - 'KeyValueStoreClient', - 'KeyValueStoreCollectionClient', + 'MemoryDatasetClient', + 'MemoryKeyValueStoreClient', + 'MemoryRequestQueueClient', 'MemoryStorageClient', - 'RequestQueueClient', - 'RequestQueueCollectionClient', ] diff --git a/src/crawlee/storage_clients/_memory/_creation_management.py b/src/crawlee/storage_clients/_memory/_creation_management.py deleted file mode 100644 index f6d4fc1c91..0000000000 --- a/src/crawlee/storage_clients/_memory/_creation_management.py +++ /dev/null @@ -1,429 +0,0 @@ -from __future__ import annotations - -import asyncio -import json -import mimetypes -import os -import pathlib -from datetime import datetime, timezone -from logging import getLogger -from typing import TYPE_CHECKING - -from crawlee._consts import METADATA_FILENAME -from crawlee._utils.data_processing import maybe_parse_body -from crawlee._utils.file import json_dumps -from crawlee.storage_clients.models import ( - DatasetMetadata, - InternalRequest, - KeyValueStoreMetadata, - KeyValueStoreRecord, - KeyValueStoreRecordMetadata, - RequestQueueMetadata, -) - -if TYPE_CHECKING: - from ._dataset_client import DatasetClient - from ._key_value_store_client import KeyValueStoreClient - from ._memory_storage_client import MemoryStorageClient, TResourceClient - from ._request_queue_client import RequestQueueClient - -logger = getLogger(__name__) - - -async def persist_metadata_if_enabled(*, data: dict, entity_directory: str, write_metadata: bool) -> None: - """Update or writes metadata to a specified directory. - - The function writes a given metadata dictionary to a JSON file within a specified directory. - The writing process is skipped if `write_metadata` is False. Before writing, it ensures that - the target directory exists, creating it if necessary. - - Args: - data: A dictionary containing metadata to be written. - entity_directory: The directory path where the metadata file should be stored. - write_metadata: A boolean flag indicating whether the metadata should be written to file. 
- """ - # Skip metadata write; ensure directory exists first - if not write_metadata: - return - - # Ensure the directory for the entity exists - await asyncio.to_thread(os.makedirs, entity_directory, exist_ok=True) - - # Write the metadata to the file - file_path = os.path.join(entity_directory, METADATA_FILENAME) - f = await asyncio.to_thread(open, file_path, mode='wb') - try: - s = await json_dumps(data) - await asyncio.to_thread(f.write, s.encode('utf-8')) - finally: - await asyncio.to_thread(f.close) - - -def find_or_create_client_by_id_or_name_inner( - resource_client_class: type[TResourceClient], - memory_storage_client: MemoryStorageClient, - id: str | None = None, - name: str | None = None, -) -> TResourceClient | None: - """Locate or create a new storage client based on the given ID or name. - - This method attempts to find a storage client in the memory cache first. If not found, - it tries to locate a storage directory by name. If still not found, it searches through - storage directories for a matching ID or name in their metadata. If none exists, and the - specified ID is 'default', it checks for a default storage directory. If a storage client - is found or created, it is added to the memory cache. If no storage client can be located or - created, the method returns None. - - Args: - resource_client_class: The class of the resource client. - memory_storage_client: The memory storage client used to store and retrieve storage clients. - id: The unique identifier for the storage client. - name: The name of the storage client. - - Raises: - ValueError: If both id and name are None. - - Returns: - The found or created storage client, or None if no client could be found or created. - """ - from ._dataset_client import DatasetClient - from ._key_value_store_client import KeyValueStoreClient - from ._request_queue_client import RequestQueueClient - - if id is None and name is None: - raise ValueError('Either id or name must be specified.') - - # First check memory cache - found = memory_storage_client.get_cached_resource_client(resource_client_class, id, name) - - if found is not None: - return found - - storage_path = _determine_storage_path(resource_client_class, memory_storage_client, id, name) - - if not storage_path: - return None - - # Create from directory if storage path is found - if issubclass(resource_client_class, DatasetClient): - resource_client = create_dataset_from_directory(storage_path, memory_storage_client, id, name) - elif issubclass(resource_client_class, KeyValueStoreClient): - resource_client = create_kvs_from_directory(storage_path, memory_storage_client, id, name) - elif issubclass(resource_client_class, RequestQueueClient): - resource_client = create_rq_from_directory(storage_path, memory_storage_client, id, name) - else: - raise TypeError('Invalid resource client class.') - - memory_storage_client.add_resource_client_to_cache(resource_client) - return resource_client - - -async def get_or_create_inner( - *, - memory_storage_client: MemoryStorageClient, - storage_client_cache: list[TResourceClient], - resource_client_class: type[TResourceClient], - name: str | None = None, - id: str | None = None, -) -> TResourceClient: - """Retrieve a named storage, or create a new one when it doesn't exist. - - Args: - memory_storage_client: The memory storage client. - storage_client_cache: The cache of storage clients. - resource_client_class: The class of the storage to retrieve or create. - name: The name of the storage to retrieve or create. 
- id: ID of the storage to retrieve or create. - - Returns: - The retrieved or newly-created storage. - """ - # If the name or id is provided, try to find the dataset in the cache - if name or id: - found = find_or_create_client_by_id_or_name_inner( - resource_client_class=resource_client_class, - memory_storage_client=memory_storage_client, - name=name, - id=id, - ) - if found: - return found - - # Otherwise, create a new one and add it to the cache - resource_client = resource_client_class( - id=id, - name=name, - memory_storage_client=memory_storage_client, - ) - - storage_client_cache.append(resource_client) - - # Write to the disk - await persist_metadata_if_enabled( - data=resource_client.resource_info.model_dump(), - entity_directory=resource_client.resource_directory, - write_metadata=memory_storage_client.write_metadata, - ) - - return resource_client - - -def create_dataset_from_directory( - storage_directory: str, - memory_storage_client: MemoryStorageClient, - id: str | None = None, - name: str | None = None, -) -> DatasetClient: - from ._dataset_client import DatasetClient - - item_count = 0 - has_seen_metadata_file = False - created_at = datetime.now(timezone.utc) - accessed_at = datetime.now(timezone.utc) - modified_at = datetime.now(timezone.utc) - - # Load metadata if it exists - metadata_filepath = os.path.join(storage_directory, METADATA_FILENAME) - - if os.path.exists(metadata_filepath): - has_seen_metadata_file = True - with open(metadata_filepath, encoding='utf-8') as f: - json_content = json.load(f) - resource_info = DatasetMetadata(**json_content) - - id = resource_info.id - name = resource_info.name - item_count = resource_info.item_count - created_at = resource_info.created_at - accessed_at = resource_info.accessed_at - modified_at = resource_info.modified_at - - # Load dataset entries - entries: dict[str, dict] = {} - - for entry in os.scandir(storage_directory): - if entry.is_file(): - if entry.name == METADATA_FILENAME: - has_seen_metadata_file = True - continue - - with open(os.path.join(storage_directory, entry.name), encoding='utf-8') as f: - entry_content = json.load(f) - - entry_name = entry.name.split('.')[0] - entries[entry_name] = entry_content - - if not has_seen_metadata_file: - item_count += 1 - - # Create new dataset client - new_client = DatasetClient( - memory_storage_client=memory_storage_client, - id=id, - name=name, - created_at=created_at, - accessed_at=accessed_at, - modified_at=modified_at, - item_count=item_count, - ) - - new_client.dataset_entries.update(entries) - return new_client - - -def create_kvs_from_directory( - storage_directory: str, - memory_storage_client: MemoryStorageClient, - id: str | None = None, - name: str | None = None, -) -> KeyValueStoreClient: - from ._key_value_store_client import KeyValueStoreClient - - created_at = datetime.now(timezone.utc) - accessed_at = datetime.now(timezone.utc) - modified_at = datetime.now(timezone.utc) - - # Load metadata if it exists - metadata_filepath = os.path.join(storage_directory, METADATA_FILENAME) - - if os.path.exists(metadata_filepath): - with open(metadata_filepath, encoding='utf-8') as f: - json_content = json.load(f) - resource_info = KeyValueStoreMetadata(**json_content) - - id = resource_info.id - name = resource_info.name - created_at = resource_info.created_at - accessed_at = resource_info.accessed_at - modified_at = resource_info.modified_at - - # Create new KVS client - new_client = KeyValueStoreClient( - memory_storage_client=memory_storage_client, - id=id, - name=name, 
- accessed_at=accessed_at, - created_at=created_at, - modified_at=modified_at, - ) - - # Scan the KVS folder, check each entry in there and parse it as a store record - for entry in os.scandir(storage_directory): - if not entry.is_file(): - continue - - # Ignore metadata files on their own - if entry.name.endswith(METADATA_FILENAME): - continue - - # Try checking if this file has a metadata file associated with it - record_metadata = None - record_metadata_filepath = os.path.join(storage_directory, f'{entry.name}.__metadata__.json') - - if os.path.exists(record_metadata_filepath): - with open(record_metadata_filepath, encoding='utf-8') as metadata_file: - try: - json_content = json.load(metadata_file) - record_metadata = KeyValueStoreRecordMetadata(**json_content) - - except Exception: - logger.warning( - f'Metadata of key-value store entry "{entry.name}" for store {name or id} could ' - 'not be parsed. The metadata file will be ignored.', - exc_info=True, - ) - - if not record_metadata: - content_type, _ = mimetypes.guess_type(entry.name) - if content_type is None: - content_type = 'application/octet-stream' - - record_metadata = KeyValueStoreRecordMetadata( - key=pathlib.Path(entry.name).stem, - content_type=content_type, - ) - - with open(os.path.join(storage_directory, entry.name), 'rb') as f: - file_content = f.read() - - try: - maybe_parse_body(file_content, record_metadata.content_type) - except Exception: - record_metadata.content_type = 'application/octet-stream' - logger.warning( - f'Key-value store entry "{record_metadata.key}" for store {name or id} could not be parsed.' - 'The entry will be assumed as binary.', - exc_info=True, - ) - - new_client.records[record_metadata.key] = KeyValueStoreRecord( - key=record_metadata.key, - content_type=record_metadata.content_type, - filename=entry.name, - value=file_content, - ) - - return new_client - - -def create_rq_from_directory( - storage_directory: str, - memory_storage_client: MemoryStorageClient, - id: str | None = None, - name: str | None = None, -) -> RequestQueueClient: - from ._request_queue_client import RequestQueueClient - - created_at = datetime.now(timezone.utc) - accessed_at = datetime.now(timezone.utc) - modified_at = datetime.now(timezone.utc) - handled_request_count = 0 - pending_request_count = 0 - - # Load metadata if it exists - metadata_filepath = os.path.join(storage_directory, METADATA_FILENAME) - - if os.path.exists(metadata_filepath): - with open(metadata_filepath, encoding='utf-8') as f: - json_content = json.load(f) - resource_info = RequestQueueMetadata(**json_content) - - id = resource_info.id - name = resource_info.name - created_at = resource_info.created_at - accessed_at = resource_info.accessed_at - modified_at = resource_info.modified_at - handled_request_count = resource_info.handled_request_count - pending_request_count = resource_info.pending_request_count - - # Load request entries - entries: dict[str, InternalRequest] = {} - - for entry in os.scandir(storage_directory): - if entry.is_file(): - if entry.name == METADATA_FILENAME: - continue - - with open(os.path.join(storage_directory, entry.name), encoding='utf-8') as f: - content = json.load(f) - - request = InternalRequest(**content) - - entries[request.id] = request - - # Create new RQ client - new_client = RequestQueueClient( - memory_storage_client=memory_storage_client, - id=id, - name=name, - accessed_at=accessed_at, - created_at=created_at, - modified_at=modified_at, - handled_request_count=handled_request_count, - 
pending_request_count=pending_request_count, - ) - - new_client.requests.update(entries) - return new_client - - -def _determine_storage_path( - resource_client_class: type[TResourceClient], - memory_storage_client: MemoryStorageClient, - id: str | None = None, - name: str | None = None, -) -> str | None: - storages_dir = memory_storage_client._get_storage_dir(resource_client_class) # noqa: SLF001 - default_id = memory_storage_client._get_default_storage_id(resource_client_class) # noqa: SLF001 - - # Try to find by name directly from directories - if name: - possible_storage_path = os.path.join(storages_dir, name) - if os.access(possible_storage_path, os.F_OK): - return possible_storage_path - - # If not found, try finding by metadata - if os.access(storages_dir, os.F_OK): - for entry in os.scandir(storages_dir): - if entry.is_dir(): - metadata_path = os.path.join(entry.path, METADATA_FILENAME) - if os.access(metadata_path, os.F_OK): - with open(metadata_path, encoding='utf-8') as metadata_file: - try: - metadata = json.load(metadata_file) - if (id and metadata.get('id') == id) or (name and metadata.get('name') == name): - return entry.path - except Exception: - logger.warning( - f'Metadata of store entry "{entry.name}" for store {name or id} could not be parsed. ' - 'The metadata file will be ignored.', - exc_info=True, - ) - - # Check for default storage directory as a last resort - if id == default_id: - possible_storage_path = os.path.join(storages_dir, default_id) - if os.access(possible_storage_path, os.F_OK): - return possible_storage_path - - return None diff --git a/src/crawlee/storage_clients/_memory/_dataset_client.py b/src/crawlee/storage_clients/_memory/_dataset_client.py index 50c8c7c8d4..0d75b50f9f 100644 --- a/src/crawlee/storage_clients/_memory/_dataset_client.py +++ b/src/crawlee/storage_clients/_memory/_dataset_client.py @@ -1,162 +1,137 @@ from __future__ import annotations -import asyncio -import json -import os -import shutil from datetime import datetime, timezone from logging import getLogger -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, ClassVar from typing_extensions import override -from crawlee._types import StorageTypes from crawlee._utils.crypto import crypto_random_object_id -from crawlee._utils.data_processing import raise_on_duplicate_storage, raise_on_non_existing_storage -from crawlee._utils.file import force_rename, json_dumps -from crawlee.storage_clients._base import DatasetClient as BaseDatasetClient +from crawlee.storage_clients._base import DatasetClient from crawlee.storage_clients.models import DatasetItemsListPage, DatasetMetadata -from ._creation_management import find_or_create_client_by_id_or_name_inner - if TYPE_CHECKING: from collections.abc import AsyncIterator - from contextlib import AbstractAsyncContextManager - - from httpx import Response - from crawlee._types import JsonSerializable - from crawlee.storage_clients import MemoryStorageClient + from crawlee.configuration import Configuration logger = getLogger(__name__) -class DatasetClient(BaseDatasetClient): - """Subclient for manipulating a single dataset.""" +class MemoryDatasetClient(DatasetClient): + """Memory implementation of the dataset client. + + This client stores dataset items in memory using Python lists and dictionaries. No data is persisted + between process runs, meaning all stored data is lost when the program terminates. 
This implementation + is primarily useful for testing, development, and short-lived crawler operations where persistent + storage is not required. - _LIST_ITEMS_LIMIT = 999_999_999_999 - """This is what API returns in the x-apify-pagination-limit header when no limit query parameter is used.""" + The memory implementation provides fast access to data but is limited by available memory and + does not support data sharing across different processes. It supports all dataset operations including + sorting, filtering, and pagination, but performs them entirely in memory. + """ - _LOCAL_ENTRY_NAME_DIGITS = 9 - """Number of characters of the dataset item file names, e.g.: 000000019.json - 9 digits.""" + _cache_by_name: ClassVar[dict[str, MemoryDatasetClient]] = {} + """A dictionary to cache clients by their names.""" def __init__( self, *, - memory_storage_client: MemoryStorageClient, - id: str | None = None, - name: str | None = None, - created_at: datetime | None = None, - accessed_at: datetime | None = None, - modified_at: datetime | None = None, - item_count: int = 0, + id: str, + name: str, + created_at: datetime, + accessed_at: datetime, + modified_at: datetime, + item_count: int, ) -> None: - self._memory_storage_client = memory_storage_client - self.id = id or crypto_random_object_id() - self.name = name - self._created_at = created_at or datetime.now(timezone.utc) - self._accessed_at = accessed_at or datetime.now(timezone.utc) - self._modified_at = modified_at or datetime.now(timezone.utc) + """Initialize a new instance. - self.dataset_entries: dict[str, dict] = {} - self.file_operation_lock = asyncio.Lock() - self.item_count = item_count - - @property - def resource_info(self) -> DatasetMetadata: - """Get the resource info for the dataset client.""" - return DatasetMetadata( - id=self.id, - name=self.name, - accessed_at=self._accessed_at, - created_at=self._created_at, - modified_at=self._modified_at, - item_count=self.item_count, + Preferably use the `MemoryDatasetClient.open` class method to create a new instance. + """ + self._metadata = DatasetMetadata( + id=id, + name=name, + created_at=created_at, + accessed_at=accessed_at, + modified_at=modified_at, + item_count=item_count, ) - @property - def resource_directory(self) -> str: - """Get the resource directory for the client.""" - return os.path.join(self._memory_storage_client.datasets_directory, self.name or self.id) + # List to hold dataset items + self._records = list[dict[str, Any]]() @override - async def get(self) -> DatasetMetadata | None: - found = find_or_create_client_by_id_or_name_inner( - resource_client_class=DatasetClient, - memory_storage_client=self._memory_storage_client, - id=self.id, - name=self.name, - ) - - if found: - async with found.file_operation_lock: - await found.update_timestamps(has_been_modified=False) - return found.resource_info - - return None + @property + def metadata(self) -> DatasetMetadata: + return self._metadata @override - async def update(self, *, name: str | None = None) -> DatasetMetadata: - # Check by id - existing_dataset_by_id = find_or_create_client_by_id_or_name_inner( - resource_client_class=DatasetClient, - memory_storage_client=self._memory_storage_client, - id=self.id, - name=self.name, + @classmethod + async def open( + cls, + *, + id: str | None, + name: str | None, + configuration: Configuration, + ) -> MemoryDatasetClient: + name = name or configuration.default_dataset_id + + # Check if the client is already cached by name. 
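+        # Reusing the cached instance keeps the in-memory records and metadata shared across
+        # repeated `open()` calls with the same name within a single process.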
+ if name in cls._cache_by_name: + client = cls._cache_by_name[name] + await client._update_metadata(update_accessed_at=True) # noqa: SLF001 + return client + + dataset_id = id or crypto_random_object_id() + now = datetime.now(timezone.utc) + + client = cls( + id=dataset_id, + name=name, + created_at=now, + accessed_at=now, + modified_at=now, + item_count=0, ) - if existing_dataset_by_id is None: - raise_on_non_existing_storage(StorageTypes.DATASET, self.id) - - # Skip if no changes - if name is None: - return existing_dataset_by_id.resource_info - - async with existing_dataset_by_id.file_operation_lock: - # Check that name is not in use already - existing_dataset_by_name = next( - ( - dataset - for dataset in self._memory_storage_client.datasets_handled - if dataset.name and dataset.name.lower() == name.lower() - ), - None, - ) - - if existing_dataset_by_name is not None: - raise_on_duplicate_storage(StorageTypes.DATASET, 'name', name) + # Cache the client by name + cls._cache_by_name[name] = client - previous_dir = existing_dataset_by_id.resource_directory - existing_dataset_by_id.name = name + return client - await force_rename(previous_dir, existing_dataset_by_id.resource_directory) - - # Update timestamps - await existing_dataset_by_id.update_timestamps(has_been_modified=True) + @override + async def drop(self) -> None: + self._records.clear() + self._metadata.item_count = 0 - return existing_dataset_by_id.resource_info + # Remove the client from the cache + if self.metadata.name in self.__class__._cache_by_name: # noqa: SLF001 + del self.__class__._cache_by_name[self.metadata.name] # noqa: SLF001 @override - async def delete(self) -> None: - dataset = next( - (dataset for dataset in self._memory_storage_client.datasets_handled if dataset.id == self.id), None + async def push_data(self, data: list[Any] | dict[str, Any]) -> None: + new_item_count = self.metadata.item_count + + if isinstance(data, list): + for item in data: + new_item_count += 1 + await self._push_item(item) + else: + new_item_count += 1 + await self._push_item(data) + + await self._update_metadata( + update_accessed_at=True, + update_modified_at=True, + new_item_count=new_item_count, ) - if dataset is not None: - async with dataset.file_operation_lock: - self._memory_storage_client.datasets_handled.remove(dataset) - dataset.item_count = 0 - dataset.dataset_entries.clear() - - if os.path.exists(dataset.resource_directory): - await asyncio.to_thread(shutil.rmtree, dataset.resource_directory) - @override - async def list_items( + async def get_data( self, *, - offset: int | None = 0, - limit: int | None = _LIST_ITEMS_LIMIT, + offset: int = 0, + limit: int | None = 999_999_999_999, clean: bool = False, desc: bool = False, fields: list[str] | None = None, @@ -167,44 +142,48 @@ async def list_items( flatten: list[str] | None = None, view: str | None = None, ) -> DatasetItemsListPage: - # Check by id - existing_dataset_by_id = find_or_create_client_by_id_or_name_inner( - resource_client_class=DatasetClient, - memory_storage_client=self._memory_storage_client, - id=self.id, - name=self.name, - ) - - if existing_dataset_by_id is None: - raise_on_non_existing_storage(StorageTypes.DATASET, self.id) - - async with existing_dataset_by_id.file_operation_lock: - start, end = existing_dataset_by_id.get_start_and_end_indexes( - max(existing_dataset_by_id.item_count - (offset or 0) - (limit or self._LIST_ITEMS_LIMIT), 0) - if desc - else offset or 0, - limit, + # Check for unsupported arguments and log a warning if found + 
unsupported_args = { + 'clean': clean, + 'fields': fields, + 'omit': omit, + 'unwind': unwind, + 'skip_hidden': skip_hidden, + 'flatten': flatten, + 'view': view, + } + unsupported = {k: v for k, v in unsupported_args.items() if v not in (False, None)} + + if unsupported: + logger.warning( + f'The arguments {list(unsupported.keys())} of get_data are not supported ' + f'by the {self.__class__.__name__} client.' ) - items = [] + total = len(self._records) + items = self._records.copy() - for idx in range(start, end): - entry_number = self._generate_local_entry_name(idx) - items.append(existing_dataset_by_id.dataset_entries[entry_number]) + # Apply skip_empty filter if requested + if skip_empty: + items = [item for item in items if item] - await existing_dataset_by_id.update_timestamps(has_been_modified=False) + # Apply sorting + if desc: + items = list(reversed(items)) - if desc: - items.reverse() + # Apply pagination + sliced_items = items[offset : (offset + limit) if limit is not None else total] - return DatasetItemsListPage( - count=len(items), - desc=desc or False, - items=items, - limit=limit or self._LIST_ITEMS_LIMIT, - offset=offset or 0, - total=existing_dataset_by_id.item_count, - ) + await self._update_metadata(update_accessed_at=True) + + return DatasetItemsListPage( + count=len(sliced_items), + offset=offset, + limit=limit or (total - offset), + total=total, + desc=desc, + items=sliced_items, + ) @override async def iterate_items( @@ -220,191 +199,66 @@ async def iterate_items( skip_empty: bool = False, skip_hidden: bool = False, ) -> AsyncIterator[dict]: - cache_size = 1000 - first_item = offset - - # If there is no limit, set last_item to None until we get the total from the first API response - last_item = None if limit is None else offset + limit - current_offset = first_item - - while last_item is None or current_offset < last_item: - current_limit = cache_size if last_item is None else min(cache_size, last_item - current_offset) - - current_items_page = await self.list_items( - offset=current_offset, - limit=current_limit, - desc=desc, + # Check for unsupported arguments and log a warning if found + unsupported_args = { + 'clean': clean, + 'fields': fields, + 'omit': omit, + 'unwind': unwind, + 'skip_hidden': skip_hidden, + } + unsupported = {k: v for k, v in unsupported_args.items() if v not in (False, None)} + + if unsupported: + logger.warning( + f'The arguments {list(unsupported.keys())} of iterate are not supported ' + f'by the {self.__class__.__name__} client.' 
) - current_offset += current_items_page.count - if last_item is None or current_items_page.total < last_item: - last_item = current_items_page.total - - for item in current_items_page.items: - yield item - - @override - async def get_items_as_bytes( - self, - *, - item_format: str = 'json', - offset: int | None = None, - limit: int | None = None, - desc: bool = False, - clean: bool = False, - bom: bool = False, - delimiter: str | None = None, - fields: list[str] | None = None, - omit: list[str] | None = None, - unwind: str | None = None, - skip_empty: bool = False, - skip_header_row: bool = False, - skip_hidden: bool = False, - xml_root: str | None = None, - xml_row: str | None = None, - flatten: list[str] | None = None, - ) -> bytes: - raise NotImplementedError('This method is not supported in memory storage.') - - @override - async def stream_items( - self, - *, - item_format: str = 'json', - offset: int | None = None, - limit: int | None = None, - desc: bool = False, - clean: bool = False, - bom: bool = False, - delimiter: str | None = None, - fields: list[str] | None = None, - omit: list[str] | None = None, - unwind: str | None = None, - skip_empty: bool = False, - skip_header_row: bool = False, - skip_hidden: bool = False, - xml_root: str | None = None, - xml_row: str | None = None, - ) -> AbstractAsyncContextManager[Response | None]: - raise NotImplementedError('This method is not supported in memory storage.') - - @override - async def push_items( - self, - items: JsonSerializable, - ) -> None: - # Check by id - existing_dataset_by_id = find_or_create_client_by_id_or_name_inner( - resource_client_class=DatasetClient, - memory_storage_client=self._memory_storage_client, - id=self.id, - name=self.name, - ) + items = self._records.copy() - if existing_dataset_by_id is None: - raise_on_non_existing_storage(StorageTypes.DATASET, self.id) + # Apply sorting + if desc: + items = list(reversed(items)) - normalized = self._normalize_items(items) + # Apply pagination + sliced_items = items[offset : (offset + limit) if limit is not None else len(items)] - added_ids: list[str] = [] - for entry in normalized: - existing_dataset_by_id.item_count += 1 - idx = self._generate_local_entry_name(existing_dataset_by_id.item_count) + # Yield items one by one + for item in sliced_items: + if skip_empty and not item: + continue + yield item - existing_dataset_by_id.dataset_entries[idx] = entry - added_ids.append(idx) - - data_entries = [(id, existing_dataset_by_id.dataset_entries[id]) for id in added_ids] - - async with existing_dataset_by_id.file_operation_lock: - await existing_dataset_by_id.update_timestamps(has_been_modified=True) - - await self._persist_dataset_items_to_disk( - data=data_entries, - entity_directory=existing_dataset_by_id.resource_directory, - persist_storage=self._memory_storage_client.persist_storage, - ) + await self._update_metadata(update_accessed_at=True) - async def _persist_dataset_items_to_disk( + async def _update_metadata( self, *, - data: list[tuple[str, dict]], - entity_directory: str, - persist_storage: bool, + new_item_count: int | None = None, + update_accessed_at: bool = False, + update_modified_at: bool = False, ) -> None: - """Write dataset items to the disk. - - The function iterates over a list of dataset items, each represented as a tuple of an identifier - and a dictionary, and writes them as individual JSON files in a specified directory. The function - will skip writing if `persist_storage` is False. 
Before writing, it ensures that the target - directory exists, creating it if necessary. + """Update the dataset metadata with current information. Args: - data: A list of tuples, each containing an identifier (string) and a data dictionary. - entity_directory: The directory path where the dataset items should be stored. - persist_storage: A boolean flag indicating whether the data should be persisted to the disk. + new_item_count: If provided, update the item count to this value. + update_accessed_at: If True, update the `accessed_at` timestamp to the current time. + update_modified_at: If True, update the `modified_at` timestamp to the current time. """ - # Skip writing files to the disk if the client has the option set to false - if not persist_storage: - return - - # Ensure the directory for the entity exists - await asyncio.to_thread(os.makedirs, entity_directory, exist_ok=True) - - # Save all the new items to the disk - for idx, item in data: - file_path = os.path.join(entity_directory, f'{idx}.json') - f = await asyncio.to_thread(open, file_path, mode='w', encoding='utf-8') - try: - s = await json_dumps(item) - await asyncio.to_thread(f.write, s) - finally: - await asyncio.to_thread(f.close) - - async def update_timestamps(self, *, has_been_modified: bool) -> None: - """Update the timestamps of the dataset.""" - from ._creation_management import persist_metadata_if_enabled - - self._accessed_at = datetime.now(timezone.utc) - - if has_been_modified: - self._modified_at = datetime.now(timezone.utc) - - await persist_metadata_if_enabled( - data=self.resource_info.model_dump(), - entity_directory=self.resource_directory, - write_metadata=self._memory_storage_client.write_metadata, - ) - - def get_start_and_end_indexes(self, offset: int, limit: int | None = None) -> tuple[int, int]: - """Calculate the start and end indexes for listing items.""" - actual_limit = limit or self.item_count - start = offset + 1 - end = min(offset + actual_limit, self.item_count) + 1 - return (start, end) - - def _generate_local_entry_name(self, idx: int) -> str: - return str(idx).zfill(self._LOCAL_ENTRY_NAME_DIGITS) - - def _normalize_items(self, items: JsonSerializable) -> list[dict]: - def normalize_item(item: Any) -> dict | None: - if isinstance(item, str): - item = json.loads(item) + now = datetime.now(timezone.utc) - if isinstance(item, list): - received = ',\n'.join(item) - raise TypeError( - f'Each dataset item can only be a single JSON object, not an array. Received: [{received}]' - ) + if update_accessed_at: + self._metadata.accessed_at = now + if update_modified_at: + self._metadata.modified_at = now + if new_item_count: + self._metadata.item_count = new_item_count - if (not isinstance(item, dict)) and item is not None: - raise TypeError(f'Each dataset item must be a JSON object. Received: {item}') + async def _push_item(self, item: dict[str, Any]) -> None: + """Push a single item to the dataset. - return item - - if isinstance(items, str): - items = json.loads(items) - - result = list(map(normalize_item, items)) if isinstance(items, list) else [normalize_item(items)] - # filter(None, ..) returns items that are True - return list(filter(None, result)) + Args: + item: The data item to add to the dataset. 
+ """ + self._records.append(item) diff --git a/src/crawlee/storage_clients/_memory/_dataset_collection_client.py b/src/crawlee/storage_clients/_memory/_dataset_collection_client.py deleted file mode 100644 index 9e32b4086b..0000000000 --- a/src/crawlee/storage_clients/_memory/_dataset_collection_client.py +++ /dev/null @@ -1,62 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING - -from typing_extensions import override - -from crawlee.storage_clients._base import DatasetCollectionClient as BaseDatasetCollectionClient -from crawlee.storage_clients.models import DatasetListPage, DatasetMetadata - -from ._creation_management import get_or_create_inner -from ._dataset_client import DatasetClient - -if TYPE_CHECKING: - from ._memory_storage_client import MemoryStorageClient - - -class DatasetCollectionClient(BaseDatasetCollectionClient): - """Subclient for manipulating datasets.""" - - def __init__(self, *, memory_storage_client: MemoryStorageClient) -> None: - self._memory_storage_client = memory_storage_client - - @property - def _storage_client_cache(self) -> list[DatasetClient]: - return self._memory_storage_client.datasets_handled - - @override - async def get_or_create( - self, - *, - name: str | None = None, - schema: dict | None = None, - id: str | None = None, - ) -> DatasetMetadata: - resource_client = await get_or_create_inner( - memory_storage_client=self._memory_storage_client, - storage_client_cache=self._storage_client_cache, - resource_client_class=DatasetClient, - name=name, - id=id, - ) - return resource_client.resource_info - - @override - async def list( - self, - *, - unnamed: bool = False, - limit: int | None = None, - offset: int | None = None, - desc: bool = False, - ) -> DatasetListPage: - items = [storage.resource_info for storage in self._storage_client_cache] - - return DatasetListPage( - total=len(items), - count=len(items), - offset=0, - limit=len(items), - desc=False, - items=sorted(items, key=lambda item: item.created_at), - ) diff --git a/src/crawlee/storage_clients/_memory/_key_value_store_client.py b/src/crawlee/storage_clients/_memory/_key_value_store_client.py index ab9def0f06..9b70419142 100644 --- a/src/crawlee/storage_clients/_memory/_key_value_store_client.py +++ b/src/crawlee/storage_clients/_memory/_key_value_store_client.py @@ -1,425 +1,190 @@ from __future__ import annotations -import asyncio -import io -import os -import shutil +import sys from datetime import datetime, timezone from logging import getLogger -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, ClassVar from typing_extensions import override -from crawlee._types import StorageTypes from crawlee._utils.crypto import crypto_random_object_id -from crawlee._utils.data_processing import maybe_parse_body, raise_on_duplicate_storage, raise_on_non_existing_storage -from crawlee._utils.file import determine_file_extension, force_remove, force_rename, is_file_or_bytes, json_dumps -from crawlee.storage_clients._base import KeyValueStoreClient as BaseKeyValueStoreClient -from crawlee.storage_clients.models import ( - KeyValueStoreKeyInfo, - KeyValueStoreListKeysPage, - KeyValueStoreMetadata, - KeyValueStoreRecord, - KeyValueStoreRecordMetadata, -) - -from ._creation_management import find_or_create_client_by_id_or_name_inner, persist_metadata_if_enabled +from crawlee._utils.file import infer_mime_type +from crawlee.storage_clients._base import KeyValueStoreClient +from crawlee.storage_clients.models import KeyValueStoreMetadata, 
KeyValueStoreRecord, KeyValueStoreRecordMetadata if TYPE_CHECKING: - from contextlib import AbstractAsyncContextManager + from collections.abc import AsyncIterator - from httpx import Response - - from crawlee.storage_clients import MemoryStorageClient + from crawlee.configuration import Configuration logger = getLogger(__name__) -class KeyValueStoreClient(BaseKeyValueStoreClient): - """Subclient for manipulating a single key-value store.""" +class MemoryKeyValueStoreClient(KeyValueStoreClient): + """Memory implementation of the key-value store client. + + This client stores data in memory as Python dictionaries. No data is persisted between + process runs, meaning all stored data is lost when the program terminates. This implementation + is primarily useful for testing, development, and short-lived crawler operations where + persistence is not required. + + The memory implementation provides fast access to data but is limited by available memory and + does not support data sharing across different processes. + """ + + _cache_by_name: ClassVar[dict[str, MemoryKeyValueStoreClient]] = {} + """A dictionary to cache clients by their names.""" def __init__( self, *, - memory_storage_client: MemoryStorageClient, - id: str | None = None, - name: str | None = None, - created_at: datetime | None = None, - accessed_at: datetime | None = None, - modified_at: datetime | None = None, + id: str, + name: str, + created_at: datetime, + accessed_at: datetime, + modified_at: datetime, ) -> None: - self.id = id or crypto_random_object_id() - self.name = name - - self._memory_storage_client = memory_storage_client - self._created_at = created_at or datetime.now(timezone.utc) - self._accessed_at = accessed_at or datetime.now(timezone.utc) - self._modified_at = modified_at or datetime.now(timezone.utc) - - self.records: dict[str, KeyValueStoreRecord] = {} - self.file_operation_lock = asyncio.Lock() - - @property - def resource_info(self) -> KeyValueStoreMetadata: - """Get the resource info for the key-value store client.""" - return KeyValueStoreMetadata( - id=self.id, - name=self.name, - accessed_at=self._accessed_at, - created_at=self._created_at, - modified_at=self._modified_at, - user_id='1', + """Initialize a new instance. + + Preferably use the `MemoryKeyValueStoreClient.open` class method to create a new instance. 
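+
+        A minimal usage sketch, assuming a plain default `Configuration`:
+
+            # `Configuration()` and the store name are illustrative assumptions.
+            client = await MemoryKeyValueStoreClient.open(
+                id=None, name='my-kvs', configuration=Configuration()
+            )
+            await client.set_value(key='state', value={'page': 1})
+            record = await client.get_value(key='state')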
+ """ + self._metadata = KeyValueStoreMetadata( + id=id, + name=name, + created_at=created_at, + accessed_at=accessed_at, + modified_at=modified_at, ) - @property - def resource_directory(self) -> str: - """Get the resource directory for the client.""" - return os.path.join(self._memory_storage_client.key_value_stores_directory, self.name or self.id) + # Dictionary to hold key-value records with metadata + self._records = dict[str, KeyValueStoreRecord]() @override - async def get(self) -> KeyValueStoreMetadata | None: - found = find_or_create_client_by_id_or_name_inner( - resource_client_class=KeyValueStoreClient, - memory_storage_client=self._memory_storage_client, - id=self.id, - name=self.name, - ) - - if found: - async with found.file_operation_lock: - await found.update_timestamps(has_been_modified=False) - return found.resource_info - - return None - - @override - async def update(self, *, name: str | None = None) -> KeyValueStoreMetadata: - # Check by id - existing_store_by_id = find_or_create_client_by_id_or_name_inner( - resource_client_class=KeyValueStoreClient, - memory_storage_client=self._memory_storage_client, - id=self.id, - name=self.name, - ) - - if existing_store_by_id is None: - raise_on_non_existing_storage(StorageTypes.KEY_VALUE_STORE, self.id) - - # Skip if no changes - if name is None: - return existing_store_by_id.resource_info - - async with existing_store_by_id.file_operation_lock: - # Check that name is not in use already - existing_store_by_name = next( - ( - store - for store in self._memory_storage_client.key_value_stores_handled - if store.name and store.name.lower() == name.lower() - ), - None, - ) - - if existing_store_by_name is not None: - raise_on_duplicate_storage(StorageTypes.KEY_VALUE_STORE, 'name', name) - - previous_dir = existing_store_by_id.resource_directory - existing_store_by_id.name = name - - await force_rename(previous_dir, existing_store_by_id.resource_directory) - - # Update timestamps - await existing_store_by_id.update_timestamps(has_been_modified=True) - - return existing_store_by_id.resource_info - - @override - async def delete(self) -> None: - store = next( - (store for store in self._memory_storage_client.key_value_stores_handled if store.id == self.id), None - ) - - if store is not None: - async with store.file_operation_lock: - self._memory_storage_client.key_value_stores_handled.remove(store) - store.records.clear() - - if os.path.exists(store.resource_directory): - await asyncio.to_thread(shutil.rmtree, store.resource_directory) + @property + def metadata(self) -> KeyValueStoreMetadata: + return self._metadata @override - async def list_keys( - self, + @classmethod + async def open( + cls, *, - limit: int = 1000, - exclusive_start_key: str | None = None, - ) -> KeyValueStoreListKeysPage: - # Check by id - existing_store_by_id = find_or_create_client_by_id_or_name_inner( - resource_client_class=KeyValueStoreClient, - memory_storage_client=self._memory_storage_client, - id=self.id, - name=self.name, + id: str | None, + name: str | None, + configuration: Configuration, + ) -> MemoryKeyValueStoreClient: + name = name or configuration.default_key_value_store_id + + # Check if the client is already cached by name + if name in cls._cache_by_name: + client = cls._cache_by_name[name] + await client._update_metadata(update_accessed_at=True) # noqa: SLF001 + return client + + # If specific id is provided, use it; otherwise, generate a new one + id = id or crypto_random_object_id() + now = datetime.now(timezone.utc) + + client = cls( + 
id=id, + name=name, + created_at=now, + accessed_at=now, + modified_at=now, ) - if existing_store_by_id is None: - raise_on_non_existing_storage(StorageTypes.KEY_VALUE_STORE, self.id) - - items: list[KeyValueStoreKeyInfo] = [] + # Cache the client by name + cls._cache_by_name[name] = client - for record in existing_store_by_id.records.values(): - size = len(record.value) - items.append(KeyValueStoreKeyInfo(key=record.key, size=size)) + return client - if len(items) == 0: - return KeyValueStoreListKeysPage( - count=len(items), - limit=limit, - exclusive_start_key=exclusive_start_key, - is_truncated=False, - next_exclusive_start_key=None, - items=items, - ) - - # Lexically sort to emulate the API - items = sorted(items, key=lambda item: item.key) + @override + async def drop(self) -> None: + # Clear all data + self._records.clear() - truncated_items = items - if exclusive_start_key is not None: - key_pos = next((idx for idx, item in enumerate(items) if item.key == exclusive_start_key), None) - if key_pos is not None: - truncated_items = items[(key_pos + 1) :] - - limited_items = truncated_items[:limit] - - last_item_in_store = items[-1] - last_selected_item = limited_items[-1] - is_last_selected_item_absolutely_last = last_item_in_store == last_selected_item - next_exclusive_start_key = None if is_last_selected_item_absolutely_last else last_selected_item.key - - async with existing_store_by_id.file_operation_lock: - await existing_store_by_id.update_timestamps(has_been_modified=False) - - return KeyValueStoreListKeysPage( - count=len(items), - limit=limit, - exclusive_start_key=exclusive_start_key, - is_truncated=not is_last_selected_item_absolutely_last, - next_exclusive_start_key=next_exclusive_start_key, - items=limited_items, - ) + # Remove from cache + if self.metadata.name in self.__class__._cache_by_name: # noqa: SLF001 + del self.__class__._cache_by_name[self.metadata.name] # noqa: SLF001 @override - async def get_record(self, key: str) -> KeyValueStoreRecord | None: - return await self._get_record_internal(key) + async def get_value(self, *, key: str) -> KeyValueStoreRecord | None: + await self._update_metadata(update_accessed_at=True) - @override - async def get_record_as_bytes(self, key: str) -> KeyValueStoreRecord[bytes] | None: - return await self._get_record_internal(key, as_bytes=True) + # Return None if key doesn't exist + return self._records.get(key, None) @override - async def stream_record(self, key: str) -> AbstractAsyncContextManager[KeyValueStoreRecord[Response] | None]: - raise NotImplementedError('This method is not supported in memory storage.') + async def set_value(self, *, key: str, value: Any, content_type: str | None = None) -> None: + content_type = content_type or infer_mime_type(value) + size = sys.getsizeof(value) - @override - async def set_record(self, key: str, value: Any, content_type: str | None = None) -> None: - # Check by id - existing_store_by_id = find_or_create_client_by_id_or_name_inner( - resource_client_class=KeyValueStoreClient, - memory_storage_client=self._memory_storage_client, - id=self.id, - name=self.name, + # Create and store the record + record = KeyValueStoreRecord( + key=key, + value=value, + content_type=content_type, + size=size, ) - if existing_store_by_id is None: - raise_on_non_existing_storage(StorageTypes.KEY_VALUE_STORE, self.id) - - if isinstance(value, io.IOBase): - raise NotImplementedError('File-like values are not supported in local memory storage') - - if content_type is None: - if is_file_or_bytes(value): - 
content_type = 'application/octet-stream' - elif isinstance(value, str): - content_type = 'text/plain; charset=utf-8' - else: - content_type = 'application/json; charset=utf-8' - - if 'application/json' in content_type and not is_file_or_bytes(value) and not isinstance(value, str): - s = await json_dumps(value) - value = s.encode('utf-8') - - async with existing_store_by_id.file_operation_lock: - await existing_store_by_id.update_timestamps(has_been_modified=True) - record = KeyValueStoreRecord(key=key, value=value, content_type=content_type, filename=None) - - old_record = existing_store_by_id.records.get(key) - existing_store_by_id.records[key] = record - - if self._memory_storage_client.persist_storage: - record_filename = self._filename_from_record(record) - record.filename = record_filename + self._records[key] = record - if old_record is not None and self._filename_from_record(old_record) != record_filename: - await existing_store_by_id.delete_persisted_record(old_record) - - await existing_store_by_id.persist_record(record) + await self._update_metadata(update_accessed_at=True, update_modified_at=True) @override - async def delete_record(self, key: str) -> None: - # Check by id - existing_store_by_id = find_or_create_client_by_id_or_name_inner( - resource_client_class=KeyValueStoreClient, - memory_storage_client=self._memory_storage_client, - id=self.id, - name=self.name, - ) - - if existing_store_by_id is None: - raise_on_non_existing_storage(StorageTypes.KEY_VALUE_STORE, self.id) - - record = existing_store_by_id.records.get(key) - - if record is not None: - async with existing_store_by_id.file_operation_lock: - del existing_store_by_id.records[key] - await existing_store_by_id.update_timestamps(has_been_modified=True) - if self._memory_storage_client.persist_storage: - await existing_store_by_id.delete_persisted_record(record) + async def delete_value(self, *, key: str) -> None: + if key in self._records: + del self._records[key] + await self._update_metadata(update_accessed_at=True, update_modified_at=True) @override - async def get_public_url(self, key: str) -> str: - existing_store_by_id = find_or_create_client_by_id_or_name_inner( - resource_client_class=KeyValueStoreClient, - memory_storage_client=self._memory_storage_client, - id=self.id, - name=self.name, - ) - - if existing_store_by_id is None: - raise_on_non_existing_storage(StorageTypes.KEY_VALUE_STORE, self.id) - - record = await self._get_record_internal(key) - - if not record: - raise ValueError(f'Record with key "{key}" was not found.') - - resource_dir = existing_store_by_id.resource_directory - record_filename = self._filename_from_record(record) - record_path = os.path.join(resource_dir, record_filename) - return f'file://{record_path}' - - async def persist_record(self, record: KeyValueStoreRecord) -> None: - """Persist the specified record to the key-value store.""" - store_directory = self.resource_directory - record_filename = self._filename_from_record(record) - record.filename = record_filename - record.content_type = record.content_type or 'application/octet-stream' - - # Ensure the directory for the entity exists - await asyncio.to_thread(os.makedirs, store_directory, exist_ok=True) - - # Create files for the record - record_path = os.path.join(store_directory, record_filename) - record_metadata_path = os.path.join(store_directory, f'{record_filename}.__metadata__.json') - - # Convert to bytes if string - if isinstance(record.value, str): - record.value = record.value.encode('utf-8') - - f = await 
asyncio.to_thread(open, record_path, mode='wb') - try: - await asyncio.to_thread(f.write, record.value) - finally: - await asyncio.to_thread(f.close) - - if self._memory_storage_client.write_metadata: - metadata_f = await asyncio.to_thread(open, record_metadata_path, mode='wb') - - try: - record_metadata = KeyValueStoreRecordMetadata(key=record.key, content_type=record.content_type) - await asyncio.to_thread(metadata_f.write, record_metadata.model_dump_json(indent=2).encode('utf-8')) - finally: - await asyncio.to_thread(metadata_f.close) - - async def delete_persisted_record(self, record: KeyValueStoreRecord) -> None: - """Delete the specified record from the key-value store.""" - store_directory = self.resource_directory - record_filename = self._filename_from_record(record) - - # Ensure the directory for the entity exists - await asyncio.to_thread(os.makedirs, store_directory, exist_ok=True) - - # Create files for the record - record_path = os.path.join(store_directory, record_filename) - record_metadata_path = os.path.join(store_directory, record_filename + '.__metadata__.json') - - await force_remove(record_path) - await force_remove(record_metadata_path) - - async def update_timestamps(self, *, has_been_modified: bool) -> None: - """Update the timestamps of the key-value store.""" - self._accessed_at = datetime.now(timezone.utc) - - if has_been_modified: - self._modified_at = datetime.now(timezone.utc) - - await persist_metadata_if_enabled( - data=self.resource_info.model_dump(), - entity_directory=self.resource_directory, - write_metadata=self._memory_storage_client.write_metadata, - ) - - async def _get_record_internal( + async def iterate_keys( self, - key: str, *, - as_bytes: bool = False, - ) -> KeyValueStoreRecord | None: - # Check by id - existing_store_by_id = find_or_create_client_by_id_or_name_inner( - resource_client_class=KeyValueStoreClient, - memory_storage_client=self._memory_storage_client, - id=self.id, - name=self.name, - ) - - if existing_store_by_id is None: - raise_on_non_existing_storage(StorageTypes.KEY_VALUE_STORE, self.id) - - stored_record = existing_store_by_id.records.get(key) - - if stored_record is None: - return None - - record = KeyValueStoreRecord( - key=stored_record.key, - value=stored_record.value, - content_type=stored_record.content_type, - filename=stored_record.filename, - ) - - if not as_bytes: - try: - record.value = maybe_parse_body(record.value, str(record.content_type)) - except ValueError: - logger.exception('Error parsing key-value store record') - - async with existing_store_by_id.file_operation_lock: - await existing_store_by_id.update_timestamps(has_been_modified=False) - - return record - - def _filename_from_record(self, record: KeyValueStoreRecord) -> str: - if record.filename is not None: - return record.filename + exclusive_start_key: str | None = None, + limit: int | None = None, + ) -> AsyncIterator[KeyValueStoreRecordMetadata]: + await self._update_metadata(update_accessed_at=True) - if not record.content_type or record.content_type == 'application/octet-stream': - return record.key + # Get all keys, sorted alphabetically + keys = sorted(self._records.keys()) - extension = determine_file_extension(record.content_type) + # Apply exclusive_start_key filter if provided + if exclusive_start_key is not None: + keys = [k for k in keys if k > exclusive_start_key] + + # Apply limit if provided + if limit is not None: + keys = keys[:limit] + + # Yield metadata for each key + for key in keys: + record = self._records[key] + yield 
KeyValueStoreRecordMetadata( + key=key, + content_type=record.content_type, + size=record.size, + ) - if record.key.endswith(f'.{extension}'): - return record.key + @override + async def get_public_url(self, *, key: str) -> str: + raise NotImplementedError('Public URLs are not supported for memory key-value stores.') - return f'{record.key}.{extension}' + async def _update_metadata( + self, + *, + update_accessed_at: bool = False, + update_modified_at: bool = False, + ) -> None: + """Update the key-value store metadata with current information. + + Args: + update_accessed_at: If True, update the `accessed_at` timestamp to the current time. + update_modified_at: If True, update the `modified_at` timestamp to the current time. + """ + now = datetime.now(timezone.utc) + + if update_accessed_at: + self._metadata.accessed_at = now + if update_modified_at: + self._metadata.modified_at = now diff --git a/src/crawlee/storage_clients/_memory/_key_value_store_collection_client.py b/src/crawlee/storage_clients/_memory/_key_value_store_collection_client.py deleted file mode 100644 index 939780449f..0000000000 --- a/src/crawlee/storage_clients/_memory/_key_value_store_collection_client.py +++ /dev/null @@ -1,62 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING - -from typing_extensions import override - -from crawlee.storage_clients._base import KeyValueStoreCollectionClient as BaseKeyValueStoreCollectionClient -from crawlee.storage_clients.models import KeyValueStoreListPage, KeyValueStoreMetadata - -from ._creation_management import get_or_create_inner -from ._key_value_store_client import KeyValueStoreClient - -if TYPE_CHECKING: - from ._memory_storage_client import MemoryStorageClient - - -class KeyValueStoreCollectionClient(BaseKeyValueStoreCollectionClient): - """Subclient for manipulating key-value stores.""" - - def __init__(self, *, memory_storage_client: MemoryStorageClient) -> None: - self._memory_storage_client = memory_storage_client - - @property - def _storage_client_cache(self) -> list[KeyValueStoreClient]: - return self._memory_storage_client.key_value_stores_handled - - @override - async def get_or_create( - self, - *, - name: str | None = None, - schema: dict | None = None, - id: str | None = None, - ) -> KeyValueStoreMetadata: - resource_client = await get_or_create_inner( - memory_storage_client=self._memory_storage_client, - storage_client_cache=self._storage_client_cache, - resource_client_class=KeyValueStoreClient, - name=name, - id=id, - ) - return resource_client.resource_info - - @override - async def list( - self, - *, - unnamed: bool = False, - limit: int | None = None, - offset: int | None = None, - desc: bool = False, - ) -> KeyValueStoreListPage: - items = [storage.resource_info for storage in self._storage_client_cache] - - return KeyValueStoreListPage( - total=len(items), - count=len(items), - offset=0, - limit=len(items), - desc=False, - items=sorted(items, key=lambda item: item.created_at), - ) diff --git a/src/crawlee/storage_clients/_memory/_memory_storage_client.py b/src/crawlee/storage_clients/_memory/_memory_storage_client.py deleted file mode 100644 index 8000f41274..0000000000 --- a/src/crawlee/storage_clients/_memory/_memory_storage_client.py +++ /dev/null @@ -1,358 +0,0 @@ -from __future__ import annotations - -import asyncio -import contextlib -import os -import shutil -from logging import getLogger -from pathlib import Path -from typing import TYPE_CHECKING, TypeVar - -from typing_extensions import override - -from 
crawlee._utils.docs import docs_group -from crawlee.configuration import Configuration -from crawlee.storage_clients import StorageClient - -from ._dataset_client import DatasetClient -from ._dataset_collection_client import DatasetCollectionClient -from ._key_value_store_client import KeyValueStoreClient -from ._key_value_store_collection_client import KeyValueStoreCollectionClient -from ._request_queue_client import RequestQueueClient -from ._request_queue_collection_client import RequestQueueCollectionClient - -if TYPE_CHECKING: - from crawlee.storage_clients._base import ResourceClient - - -TResourceClient = TypeVar('TResourceClient', DatasetClient, KeyValueStoreClient, RequestQueueClient) - -logger = getLogger(__name__) - - -@docs_group('Classes') -class MemoryStorageClient(StorageClient): - """Represents an in-memory storage client for managing datasets, key-value stores, and request queues. - - It emulates in-memory storage similar to the Apify platform, supporting both in-memory and local file system-based - persistence. - - The behavior of the storage, such as data persistence and metadata writing, can be customized via initialization - parameters or environment variables. - """ - - _MIGRATING_KEY_VALUE_STORE_DIR_NAME = '__CRAWLEE_MIGRATING_KEY_VALUE_STORE' - """Name of the directory used to temporarily store files during the migration of the default key-value store.""" - - _TEMPORARY_DIR_NAME = '__CRAWLEE_TEMPORARY' - """Name of the directory used to temporarily store files during purges.""" - - _DATASETS_DIR_NAME = 'datasets' - """Name of the directory containing datasets.""" - - _KEY_VALUE_STORES_DIR_NAME = 'key_value_stores' - """Name of the directory containing key-value stores.""" - - _REQUEST_QUEUES_DIR_NAME = 'request_queues' - """Name of the directory containing request queues.""" - - def __init__( - self, - *, - write_metadata: bool, - persist_storage: bool, - storage_dir: str, - default_request_queue_id: str, - default_key_value_store_id: str, - default_dataset_id: str, - ) -> None: - """Initialize a new instance. - - In most cases, you should use the `from_config` constructor to create a new instance based on - the provided configuration. - - Args: - write_metadata: Whether to write metadata to the storage. - persist_storage: Whether to persist the storage. - storage_dir: Path to the storage directory. - default_request_queue_id: The default request queue ID. - default_key_value_store_id: The default key-value store ID. - default_dataset_id: The default dataset ID. - """ - # Set the internal attributes. - self._write_metadata = write_metadata - self._persist_storage = persist_storage - self._storage_dir = storage_dir - self._default_request_queue_id = default_request_queue_id - self._default_key_value_store_id = default_key_value_store_id - self._default_dataset_id = default_dataset_id - - self.datasets_handled: list[DatasetClient] = [] - self.key_value_stores_handled: list[KeyValueStoreClient] = [] - self.request_queues_handled: list[RequestQueueClient] = [] - - self._purged_on_start = False # Indicates whether a purge was already performed on this instance. - self._purge_lock = asyncio.Lock() - - @classmethod - def from_config(cls, config: Configuration | None = None) -> MemoryStorageClient: - """Initialize a new instance based on the provided `Configuration`. - - Args: - config: The `Configuration` instance. Uses the global (default) one if not provided. 
- """ - config = config or Configuration.get_global_configuration() - - return cls( - write_metadata=config.write_metadata, - persist_storage=config.persist_storage, - storage_dir=config.storage_dir, - default_request_queue_id=config.default_request_queue_id, - default_key_value_store_id=config.default_key_value_store_id, - default_dataset_id=config.default_dataset_id, - ) - - @property - def write_metadata(self) -> bool: - """Whether to write metadata to the storage.""" - return self._write_metadata - - @property - def persist_storage(self) -> bool: - """Whether to persist the storage.""" - return self._persist_storage - - @property - def storage_dir(self) -> str: - """Path to the storage directory.""" - return self._storage_dir - - @property - def datasets_directory(self) -> str: - """Path to the directory containing datasets.""" - return os.path.join(self.storage_dir, self._DATASETS_DIR_NAME) - - @property - def key_value_stores_directory(self) -> str: - """Path to the directory containing key-value stores.""" - return os.path.join(self.storage_dir, self._KEY_VALUE_STORES_DIR_NAME) - - @property - def request_queues_directory(self) -> str: - """Path to the directory containing request queues.""" - return os.path.join(self.storage_dir, self._REQUEST_QUEUES_DIR_NAME) - - @override - def dataset(self, id: str) -> DatasetClient: - return DatasetClient(memory_storage_client=self, id=id) - - @override - def datasets(self) -> DatasetCollectionClient: - return DatasetCollectionClient(memory_storage_client=self) - - @override - def key_value_store(self, id: str) -> KeyValueStoreClient: - return KeyValueStoreClient(memory_storage_client=self, id=id) - - @override - def key_value_stores(self) -> KeyValueStoreCollectionClient: - return KeyValueStoreCollectionClient(memory_storage_client=self) - - @override - def request_queue(self, id: str) -> RequestQueueClient: - return RequestQueueClient(memory_storage_client=self, id=id) - - @override - def request_queues(self) -> RequestQueueCollectionClient: - return RequestQueueCollectionClient(memory_storage_client=self) - - @override - async def purge_on_start(self) -> None: - # Optimistic, non-blocking check - if self._purged_on_start is True: - logger.debug('Storage was already purged on start.') - return - - async with self._purge_lock: - # Another check under the lock just to be sure - if self._purged_on_start is True: - # Mypy doesn't understand that the _purged_on_start can change while we're getting the async lock - return # type: ignore[unreachable] - - await self._purge_default_storages() - self._purged_on_start = True - - def get_cached_resource_client( - self, - resource_client_class: type[TResourceClient], - id: str | None, - name: str | None, - ) -> TResourceClient | None: - """Try to return a resource client from the internal cache.""" - if issubclass(resource_client_class, DatasetClient): - cache = self.datasets_handled - elif issubclass(resource_client_class, KeyValueStoreClient): - cache = self.key_value_stores_handled - elif issubclass(resource_client_class, RequestQueueClient): - cache = self.request_queues_handled - else: - return None - - for storage_client in cache: - if storage_client.id == id or ( - storage_client.name and name and storage_client.name.lower() == name.lower() - ): - return storage_client - - return None - - def add_resource_client_to_cache(self, resource_client: ResourceClient) -> None: - """Add a new resource client to the internal cache.""" - if isinstance(resource_client, DatasetClient): - 
self.datasets_handled.append(resource_client) - if isinstance(resource_client, KeyValueStoreClient): - self.key_value_stores_handled.append(resource_client) - if isinstance(resource_client, RequestQueueClient): - self.request_queues_handled.append(resource_client) - - async def _purge_default_storages(self) -> None: - """Clean up the storage directories, preparing the environment for a new run. - - It aims to remove residues from previous executions to avoid data contamination between runs. - - It specifically targets: - - The local directory containing the default dataset. - - All records from the default key-value store in the local directory, except for the 'INPUT' key. - - The local directory containing the default request queue. - """ - # Key-value stores - if await asyncio.to_thread(os.path.exists, self.key_value_stores_directory): - key_value_store_folders = await asyncio.to_thread(os.scandir, self.key_value_stores_directory) - for key_value_store_folder in key_value_store_folders: - if key_value_store_folder.name.startswith( - self._TEMPORARY_DIR_NAME - ) or key_value_store_folder.name.startswith('__OLD'): - await self._batch_remove_files(key_value_store_folder.path) - elif key_value_store_folder.name == self._default_key_value_store_id: - await self._handle_default_key_value_store(key_value_store_folder.path) - - # Datasets - if await asyncio.to_thread(os.path.exists, self.datasets_directory): - dataset_folders = await asyncio.to_thread(os.scandir, self.datasets_directory) - for dataset_folder in dataset_folders: - if dataset_folder.name == self._default_dataset_id or dataset_folder.name.startswith( - self._TEMPORARY_DIR_NAME - ): - await self._batch_remove_files(dataset_folder.path) - - # Request queues - if await asyncio.to_thread(os.path.exists, self.request_queues_directory): - request_queue_folders = await asyncio.to_thread(os.scandir, self.request_queues_directory) - for request_queue_folder in request_queue_folders: - if request_queue_folder.name == self._default_request_queue_id or request_queue_folder.name.startswith( - self._TEMPORARY_DIR_NAME - ): - await self._batch_remove_files(request_queue_folder.path) - - async def _handle_default_key_value_store(self, folder: str) -> None: - """Manage the cleanup of the default key-value store. - - It removes all files to ensure a clean state except for a set of predefined input keys (`possible_input_keys`). - - Args: - folder: Path to the default key-value store directory to clean. 
- """ - folder_exists = await asyncio.to_thread(os.path.exists, folder) - temporary_path = os.path.normpath(os.path.join(folder, '..', self._MIGRATING_KEY_VALUE_STORE_DIR_NAME)) - - # For optimization, we want to only attempt to copy a few files from the default key-value store - possible_input_keys = [ - 'INPUT', - 'INPUT.json', - 'INPUT.bin', - 'INPUT.txt', - ] - - if folder_exists: - # Create a temporary folder to save important files in - Path(temporary_path).mkdir(parents=True, exist_ok=True) - - # Go through each file and save the ones that are important - for entity in possible_input_keys: - original_file_path = os.path.join(folder, entity) - temp_file_path = os.path.join(temporary_path, entity) - with contextlib.suppress(Exception): - await asyncio.to_thread(os.rename, original_file_path, temp_file_path) - - # Remove the original folder and all its content - counter = 0 - temp_path_for_old_folder = os.path.normpath(os.path.join(folder, f'../__OLD_DEFAULT_{counter}__')) - done = False - try: - while not done: - await asyncio.to_thread(os.rename, folder, temp_path_for_old_folder) - done = True - except Exception: - counter += 1 - temp_path_for_old_folder = os.path.normpath(os.path.join(folder, f'../__OLD_DEFAULT_{counter}__')) - - # Replace the temporary folder with the original folder - await asyncio.to_thread(os.rename, temporary_path, folder) - - # Remove the old folder - await self._batch_remove_files(temp_path_for_old_folder) - - async def _batch_remove_files(self, folder: str, counter: int = 0) -> None: - """Remove a folder and its contents in batches to minimize blocking time. - - This method first renames the target folder to a temporary name, then deletes the temporary folder, - allowing the file system operations to proceed without hindering other asynchronous tasks. - - Args: - folder: The directory path to remove. - counter: A counter used for generating temporary directory names in case of conflicts. 
- """ - folder_exists = await asyncio.to_thread(os.path.exists, folder) - - if folder_exists: - temporary_folder = ( - folder - if os.path.basename(folder).startswith(f'{self._TEMPORARY_DIR_NAME}_') - else os.path.normpath(os.path.join(folder, '..', f'{self._TEMPORARY_DIR_NAME}_{counter}')) - ) - - try: - # Rename the old folder to the new one to allow background deletions - await asyncio.to_thread(os.rename, folder, temporary_folder) - except Exception: - # Folder exists already, try again with an incremented counter - return await self._batch_remove_files(folder, counter + 1) - - await asyncio.to_thread(shutil.rmtree, temporary_folder, ignore_errors=True) - return None - - def _get_default_storage_id(self, storage_client_class: type[TResourceClient]) -> str: - """Get the default storage ID based on the storage class.""" - if issubclass(storage_client_class, DatasetClient): - return self._default_dataset_id - - if issubclass(storage_client_class, KeyValueStoreClient): - return self._default_key_value_store_id - - if issubclass(storage_client_class, RequestQueueClient): - return self._default_request_queue_id - - raise ValueError(f'Invalid storage class: {storage_client_class.__name__}') - - def _get_storage_dir(self, storage_client_class: type[TResourceClient]) -> str: - """Get the storage directory based on the storage class.""" - if issubclass(storage_client_class, DatasetClient): - return self.datasets_directory - - if issubclass(storage_client_class, KeyValueStoreClient): - return self.key_value_stores_directory - - if issubclass(storage_client_class, RequestQueueClient): - return self.request_queues_directory - - raise ValueError(f'Invalid storage class: {storage_client_class.__name__}') diff --git a/src/crawlee/storage_clients/_memory/_request_queue_client.py b/src/crawlee/storage_clients/_memory/_request_queue_client.py index 477d53df07..c8e58e0515 100644 --- a/src/crawlee/storage_clients/_memory/_request_queue_client.py +++ b/src/crawlee/storage_clients/_memory/_request_queue_client.py @@ -1,558 +1,363 @@ from __future__ import annotations -import asyncio -import os -import shutil from datetime import datetime, timezone -from decimal import Decimal from logging import getLogger -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, ClassVar from sortedcollections import ValueSortedDict from typing_extensions import override -from crawlee._types import StorageTypes +from crawlee import Request from crawlee._utils.crypto import crypto_random_object_id -from crawlee._utils.data_processing import raise_on_duplicate_storage, raise_on_non_existing_storage -from crawlee._utils.file import force_remove, force_rename, json_dumps -from crawlee._utils.requests import unique_key_to_request_id -from crawlee.storage_clients._base import RequestQueueClient as BaseRequestQueueClient +from crawlee.storage_clients._base import RequestQueueClient from crawlee.storage_clients.models import ( - BatchRequestsOperationResponse, - InternalRequest, + AddRequestsResponse, ProcessedRequest, - ProlongRequestLockResponse, - RequestQueueHead, - RequestQueueHeadWithLocks, RequestQueueMetadata, - UnprocessedRequest, ) -from ._creation_management import find_or_create_client_by_id_or_name_inner, persist_metadata_if_enabled - if TYPE_CHECKING: from collections.abc import Sequence - from sortedcontainers import SortedDict + from crawlee.configuration import Configuration - from crawlee import Request +logger = getLogger(__name__) - from ._memory_storage_client import MemoryStorageClient -logger = 
getLogger(__name__) +class MemoryRequestQueueClient(RequestQueueClient): + """Memory implementation of the request queue client. + This client stores requests in memory using a Python list and dictionary. No data is persisted between + process runs, which means all requests are lost when the program terminates. This implementation + is primarily useful for testing, development, and short-lived crawler runs where persistence + is not required. -class RequestQueueClient(BaseRequestQueueClient): - """Subclient for manipulating a single request queue.""" + This client provides fast access to request data but is limited by available memory and + does not support data sharing across different processes. + """ + + _cache_by_name: ClassVar[dict[str, MemoryRequestQueueClient]] = {} + """A dictionary to cache clients by their names.""" def __init__( self, *, - memory_storage_client: MemoryStorageClient, - id: str | None = None, - name: str | None = None, - created_at: datetime | None = None, - accessed_at: datetime | None = None, - modified_at: datetime | None = None, - handled_request_count: int = 0, - pending_request_count: int = 0, + id: str, + name: str, + created_at: datetime, + accessed_at: datetime, + modified_at: datetime, + had_multiple_clients: bool, + handled_request_count: int, + pending_request_count: int, + stats: dict, + total_request_count: int, ) -> None: - self._memory_storage_client = memory_storage_client - self.id = id or crypto_random_object_id() - self.name = name - self._created_at = created_at or datetime.now(timezone.utc) - self._accessed_at = accessed_at or datetime.now(timezone.utc) - self._modified_at = modified_at or datetime.now(timezone.utc) - self.handled_request_count = handled_request_count - self.pending_request_count = pending_request_count - - self.requests: SortedDict[str, InternalRequest] = ValueSortedDict( - lambda request: request.order_no or -float('inf') - ) - self.file_operation_lock = asyncio.Lock() - self._last_used_timestamp = Decimal(0) + """Initialize a new instance. - self._in_progress = set[str]() - - @property - def resource_info(self) -> RequestQueueMetadata: - """Get the resource info for the request queue client.""" - return RequestQueueMetadata( - id=self.id, - name=self.name, - accessed_at=self._accessed_at, - created_at=self._created_at, - modified_at=self._modified_at, - had_multiple_clients=False, - handled_request_count=self.handled_request_count, - pending_request_count=self.pending_request_count, - stats={}, - total_request_count=len(self.requests), - user_id='1', - resource_directory=self.resource_directory, - ) - - @property - def resource_directory(self) -> str: - """Get the resource directory for the client.""" - return os.path.join(self._memory_storage_client.request_queues_directory, self.name or self.id) - - @override - async def get(self) -> RequestQueueMetadata | None: - found = find_or_create_client_by_id_or_name_inner( - resource_client_class=RequestQueueClient, - memory_storage_client=self._memory_storage_client, - id=self.id, - name=self.name, + Preferably use the `MemoryRequestQueueClient.open` class method to create a new instance. 
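+
+        A minimal usage sketch, assuming a plain default `Configuration`:
+
+            # `Configuration()`, the queue name and the URL are illustrative assumptions.
+            queue = await MemoryRequestQueueClient.open(
+                id=None, name='my-queue', configuration=Configuration()
+            )
+            await queue.add_batch_of_requests([Request.from_url('https://crawlee.dev/')])
+            request = await queue.fetch_next_request()
+            if request is not None:
+                await queue.mark_request_as_handled(request)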
+ """ + self._metadata = RequestQueueMetadata( + id=id, + name=name, + created_at=created_at, + accessed_at=accessed_at, + modified_at=modified_at, + had_multiple_clients=had_multiple_clients, + handled_request_count=handled_request_count, + pending_request_count=pending_request_count, + stats=stats, + total_request_count=total_request_count, ) - if found: - async with found.file_operation_lock: - await found.update_timestamps(has_been_modified=False) - return found.resource_info + # List to hold RQ items + self._records = list[Request]() - return None + # Dictionary to track in-progress requests (fetched but not yet handled or reclaimed) + self._in_progress = dict[str, Request]() @override - async def update(self, *, name: str | None = None) -> RequestQueueMetadata: - # Check by id - existing_queue_by_id = find_or_create_client_by_id_or_name_inner( - resource_client_class=RequestQueueClient, - memory_storage_client=self._memory_storage_client, - id=self.id, - name=self.name, - ) - - if existing_queue_by_id is None: - raise_on_non_existing_storage(StorageTypes.REQUEST_QUEUE, self.id) - - # Skip if no changes - if name is None: - return existing_queue_by_id.resource_info - - async with existing_queue_by_id.file_operation_lock: - # Check that name is not in use already - existing_queue_by_name = next( - ( - queue - for queue in self._memory_storage_client.request_queues_handled - if queue.name and queue.name.lower() == name.lower() - ), - None, - ) - - if existing_queue_by_name is not None: - raise_on_duplicate_storage(StorageTypes.REQUEST_QUEUE, 'name', name) - - previous_dir = existing_queue_by_id.resource_directory - existing_queue_by_id.name = name - - await force_rename(previous_dir, existing_queue_by_id.resource_directory) - - # Update timestamps - await existing_queue_by_id.update_timestamps(has_been_modified=True) - - return existing_queue_by_id.resource_info + @property + def metadata(self) -> RequestQueueMetadata: + return self._metadata @override - async def delete(self) -> None: - queue = next( - (queue for queue in self._memory_storage_client.request_queues_handled if queue.id == self.id), - None, + @classmethod + async def open( + cls, + *, + id: str | None, + name: str | None, + configuration: Configuration, + ) -> MemoryRequestQueueClient: + name = name or configuration.default_request_queue_id + + # Check if the client is already cached by name + if name in cls._cache_by_name: + client = cls._cache_by_name[name] + await client._update_metadata(update_accessed_at=True) # noqa: SLF001 + return client + + # If specific id is provided, use it; otherwise, generate a new one + id = id or crypto_random_object_id() + now = datetime.now(timezone.utc) + + client = cls( + id=crypto_random_object_id(), + name=name, + created_at=now, + accessed_at=now, + modified_at=now, + had_multiple_clients=False, + handled_request_count=0, + pending_request_count=0, + stats={}, + total_request_count=0, ) - if queue is not None: - async with queue.file_operation_lock: - self._memory_storage_client.request_queues_handled.remove(queue) - queue.pending_request_count = 0 - queue.handled_request_count = 0 - queue.requests.clear() + # Cache the client by name + cls._cache_by_name[name] = client - if os.path.exists(queue.resource_directory): - await asyncio.to_thread(shutil.rmtree, queue.resource_directory) + return client @override - async def list_head(self, *, limit: int | None = None, skip_in_progress: bool = False) -> RequestQueueHead: - existing_queue_by_id = find_or_create_client_by_id_or_name_inner( 
- resource_client_class=RequestQueueClient, - memory_storage_client=self._memory_storage_client, - id=self.id, - name=self.name, - ) - - if existing_queue_by_id is None: - raise_on_non_existing_storage(StorageTypes.REQUEST_QUEUE, self.id) - - async with existing_queue_by_id.file_operation_lock: - await existing_queue_by_id.update_timestamps(has_been_modified=False) - - requests: list[Request] = [] - - # Iterate all requests in the queue which have sorted key larger than infinity, which means - # `order_no` is not `None`. This will iterate them in order of `order_no`. - for request_key in existing_queue_by_id.requests.irange_key( # type: ignore[attr-defined] # irange_key is a valid SortedDict method but not recognized by mypy - min_key=-float('inf'), inclusive=(False, True) - ): - if len(requests) == limit: - break + async def drop(self) -> None: + # Clear all data + self._records.clear() + self._in_progress.clear() - if skip_in_progress and request_key in existing_queue_by_id._in_progress: # noqa: SLF001 - continue - internal_request = existing_queue_by_id.requests.get(request_key) - - # Check that the request still exists and was not handled, - # in case something deleted it or marked it as handled concurrenctly - if internal_request and not internal_request.handled_at: - requests.append(internal_request.to_request()) - - return RequestQueueHead( - limit=limit, - had_multiple_clients=False, - queue_modified_at=existing_queue_by_id._modified_at, # noqa: SLF001 - items=requests, - ) + # Remove from cache + if self.metadata.name in self.__class__._cache_by_name: # noqa: SLF001 + del self.__class__._cache_by_name[self.metadata.name] # noqa: SLF001 @override - async def list_and_lock_head(self, *, lock_secs: int, limit: int | None = None) -> RequestQueueHeadWithLocks: - existing_queue_by_id = find_or_create_client_by_id_or_name_inner( - resource_client_class=RequestQueueClient, - memory_storage_client=self._memory_storage_client, - id=self.id, - name=self.name, - ) - - if existing_queue_by_id is None: - raise_on_non_existing_storage(StorageTypes.REQUEST_QUEUE, self.id) - - result = await self.list_head(limit=limit, skip_in_progress=True) - - for item in result.items: - existing_queue_by_id._in_progress.add(item.id) # noqa: SLF001 - - return RequestQueueHeadWithLocks( - queue_has_locked_requests=len(existing_queue_by_id._in_progress) > 0, # noqa: SLF001 - lock_secs=lock_secs, - limit=result.limit, - had_multiple_clients=result.had_multiple_clients, - queue_modified_at=result.queue_modified_at, - items=result.items, - ) - - @override - async def add_request( + async def add_batch_of_requests( self, - request: Request, + requests: Sequence[Request], *, forefront: bool = False, - ) -> ProcessedRequest: - existing_queue_by_id = find_or_create_client_by_id_or_name_inner( - resource_client_class=RequestQueueClient, - memory_storage_client=self._memory_storage_client, - id=self.id, - name=self.name, - ) + ) -> AddRequestsResponse: + """Add a batch of requests to the queue. - if existing_queue_by_id is None: - raise_on_non_existing_storage(StorageTypes.REQUEST_QUEUE, self.id) + Args: + requests: The requests to add. + forefront: Whether to add the requests to the beginning of the queue. - internal_request = await self._create_internal_request(request, forefront) + Returns: + Response containing information about the added requests. 
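+
+        A short sketch, assuming `client` is an already opened instance of this class.
+        Requests whose `unique_key` is already in the queue are reported as already
+        present instead of being added twice:
+
+            # The URLs and the `client` variable are illustrative assumptions.
+            response = await client.add_batch_of_requests(
+                [Request.from_url('https://crawlee.dev/'), Request.from_url('https://apify.com/')]
+            )
+            fresh = [r for r in response.processed_requests if not r.was_already_present]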
+ """ + processed_requests = [] + for request in requests: + # Ensure the request has an ID + if not request.id: + request.id = crypto_random_object_id() - async with existing_queue_by_id.file_operation_lock: - existing_internal_request_with_id = existing_queue_by_id.requests.get(internal_request.id) + # Check if the request is already in the queue by unique_key + existing_request = next((r for r in self._records if r.unique_key == request.unique_key), None) - # We already have the request present, so we return information about it - if existing_internal_request_with_id is not None: - await existing_queue_by_id.update_timestamps(has_been_modified=False) + was_already_present = existing_request is not None + was_already_handled = was_already_present and existing_request and existing_request.handled_at is not None - return ProcessedRequest( - id=internal_request.id, - unique_key=internal_request.unique_key, - was_already_present=True, - was_already_handled=existing_internal_request_with_id.handled_at is not None, + # If the request is already in the queue and handled, don't add it again + if was_already_handled: + processed_requests.append( + ProcessedRequest( + id=request.id, + unique_key=request.unique_key, + was_already_present=True, + was_already_handled=True, + ) ) - - existing_queue_by_id.requests[internal_request.id] = internal_request - if internal_request.handled_at: - existing_queue_by_id.handled_request_count += 1 + continue + + # If the request is already in the queue but not handled, update it + if was_already_present: + # Update the existing request with any new data + for idx, rec in enumerate(self._records): + if rec.unique_key == request.unique_key: + self._records[idx] = request + break else: - existing_queue_by_id.pending_request_count += 1 - await existing_queue_by_id.update_timestamps(has_been_modified=True) - await self._persist_single_request_to_storage( - request=internal_request, - entity_directory=existing_queue_by_id.resource_directory, - persist_storage=self._memory_storage_client.persist_storage, + # Add the new request to the queue + if forefront: + self._records.insert(0, request) + else: + self._records.append(request) + + # Update metadata counts + self._metadata.total_request_count += 1 + self._metadata.pending_request_count += 1 + + processed_requests.append( + ProcessedRequest( + id=request.id, + unique_key=request.unique_key, + was_already_present=was_already_present, + was_already_handled=False, + ) ) - # We return was_already_handled=False even though the request may have been added as handled, - # because that's how API behaves. - return ProcessedRequest( - id=internal_request.id, - unique_key=internal_request.unique_key, - was_already_present=False, - was_already_handled=False, - ) + await self._update_metadata(update_accessed_at=True, update_modified_at=True) - @override - async def get_request(self, request_id: str) -> Request | None: - existing_queue_by_id = find_or_create_client_by_id_or_name_inner( - resource_client_class=RequestQueueClient, - memory_storage_client=self._memory_storage_client, - id=self.id, - name=self.name, + return AddRequestsResponse( + processed_requests=processed_requests, + unprocessed_requests=[], ) - if existing_queue_by_id is None: - raise_on_non_existing_storage(StorageTypes.REQUEST_QUEUE, self.id) + @override + async def fetch_next_request(self) -> Request | None: + """Return the next request in the queue to be processed. 
- async with existing_queue_by_id.file_operation_lock: - await existing_queue_by_id.update_timestamps(has_been_modified=False) + Returns: + The request or `None` if there are no more pending requests. + """ + # Find the first request that's not handled or in progress + for request in self._records: + if request.handled_at is None and request.id not in self._in_progress: + # Mark as in progress + self._in_progress[request.id] = request + return request - internal_request = existing_queue_by_id.requests.get(request_id) - return internal_request.to_request() if internal_request else None + return None @override - async def update_request( - self, - request: Request, - *, - forefront: bool = False, - ) -> ProcessedRequest: - existing_queue_by_id = find_or_create_client_by_id_or_name_inner( - resource_client_class=RequestQueueClient, - memory_storage_client=self._memory_storage_client, - id=self.id, - name=self.name, - ) - - if existing_queue_by_id is None: - raise_on_non_existing_storage(StorageTypes.REQUEST_QUEUE, self.id) + async def get_request(self, request_id: str) -> Request | None: + """Retrieve a request from the queue. - internal_request = await self._create_internal_request(request, forefront) + Args: + request_id: ID of the request to retrieve. - # First we need to check the existing request to be able to return information about its handled state. - existing_internal_request = existing_queue_by_id.requests.get(internal_request.id) + Returns: + The retrieved request, or None, if it did not exist. + """ + # Check in-progress requests first + if request_id in self._in_progress: + return self._in_progress[request_id] - # Undefined means that the request is not present in the queue. - # We need to insert it, to behave the same as API. - if existing_internal_request is None: - return await self.add_request(request, forefront=forefront) + # Otherwise search in the records + for request in self._records: + if request.id == request_id: + return request - async with existing_queue_by_id.file_operation_lock: - # When updating the request, we need to make sure that - # the handled counts are updated correctly in all cases. - existing_queue_by_id.requests[internal_request.id] = internal_request + return None - pending_count_adjustment = 0 - is_request_handled_state_changing = existing_internal_request.handled_at != internal_request.handled_at + @override + async def mark_request_as_handled(self, request: Request) -> ProcessedRequest | None: + """Mark a request as handled after successful processing. - request_was_handled_before_update = existing_internal_request.handled_at is not None + Handled requests will never again be returned by the `fetch_next_request` method. - # We add 1 pending request if previous state was handled - if is_request_handled_state_changing: - pending_count_adjustment = 1 if request_was_handled_before_update else -1 + Args: + request: The request to mark as handled. - existing_queue_by_id.pending_request_count += pending_count_adjustment - existing_queue_by_id.handled_request_count -= pending_count_adjustment - await existing_queue_by_id.update_timestamps(has_been_modified=True) - await self._persist_single_request_to_storage( - request=internal_request, - entity_directory=existing_queue_by_id.resource_directory, - persist_storage=self._memory_storage_client.persist_storage, - ) + Returns: + Information about the queue operation. `None` if the given request was not in progress. 
+ """ + # Check if the request is in progress + if request.id not in self._in_progress: + return None - if request.handled_at is not None: - existing_queue_by_id._in_progress.discard(request.id) # noqa: SLF001 + # Set handled_at timestamp if not already set + if request.handled_at is None: + request.handled_at = datetime.now(timezone.utc) - return ProcessedRequest( - id=internal_request.id, - unique_key=internal_request.unique_key, - was_already_present=True, - was_already_handled=request_was_handled_before_update, - ) + # Update the request in records + for idx, rec in enumerate(self._records): + if rec.id == request.id: + self._records[idx] = request + break - @override - async def delete_request(self, request_id: str) -> None: - existing_queue_by_id = find_or_create_client_by_id_or_name_inner( - resource_client_class=RequestQueueClient, - memory_storage_client=self._memory_storage_client, - id=self.id, - name=self.name, - ) + # Remove from in-progress + del self._in_progress[request.id] - if existing_queue_by_id is None: - raise_on_non_existing_storage(StorageTypes.REQUEST_QUEUE, self.id) + # Update metadata counts + self._metadata.handled_request_count += 1 + self._metadata.pending_request_count -= 1 - async with existing_queue_by_id.file_operation_lock: - internal_request = existing_queue_by_id.requests.get(request_id) + # Update metadata timestamps + await self._update_metadata(update_modified_at=True) - if internal_request: - del existing_queue_by_id.requests[request_id] - if internal_request.handled_at: - existing_queue_by_id.handled_request_count -= 1 - else: - existing_queue_by_id.pending_request_count -= 1 - await existing_queue_by_id.update_timestamps(has_been_modified=True) - await self._delete_request_file_from_storage( - entity_directory=existing_queue_by_id.resource_directory, - request_id=request_id, - ) + return ProcessedRequest( + id=request.id, + unique_key=request.unique_key, + was_already_present=True, + was_already_handled=True, + ) @override - async def prolong_request_lock( + async def reclaim_request( self, - request_id: str, + request: Request, *, forefront: bool = False, - lock_secs: int, - ) -> ProlongRequestLockResponse: - return ProlongRequestLockResponse(lock_expires_at=datetime.now(timezone.utc)) + ) -> ProcessedRequest | None: + """Reclaim a failed request back to the queue. - @override - async def delete_request_lock( - self, - request_id: str, - *, - forefront: bool = False, - ) -> None: - existing_queue_by_id = find_or_create_client_by_id_or_name_inner( - resource_client_class=RequestQueueClient, - memory_storage_client=self._memory_storage_client, - id=self.id, - name=self.name, - ) + The request will be returned for processing later again by another call to `fetch_next_request`. - if existing_queue_by_id is None: - raise_on_non_existing_storage(StorageTypes.REQUEST_QUEUE, self.id) + Args: + request: The request to return to the queue. + forefront: Whether to add the request to the head or the end of the queue. - existing_queue_by_id._in_progress.discard(request_id) # noqa: SLF001 + Returns: + Information about the queue operation. `None` if the given request was not in progress. 
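+
+        A short sketch, assuming `request` was previously returned by `fetch_next_request`
+        on the same (assumed) `client` instance:
+
+            # Put the failed request back at the head of the queue so it is retried first.
+            await client.reclaim_request(request, forefront=True)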
+ """ + # Check if the request is in progress + if request.id not in self._in_progress: + return None - @override - async def batch_add_requests( - self, - requests: Sequence[Request], - *, - forefront: bool = False, - ) -> BatchRequestsOperationResponse: - processed_requests = list[ProcessedRequest]() - unprocessed_requests = list[UnprocessedRequest]() + # Remove from in-progress + del self._in_progress[request.id] - for request in requests: - try: - processed_request = await self.add_request(request, forefront=forefront) - processed_requests.append( - ProcessedRequest( - id=processed_request.id, - unique_key=processed_request.unique_key, - was_already_present=processed_request.was_already_present, - was_already_handled=processed_request.was_already_handled, - ) - ) - except Exception as exc: # noqa: PERF203 - logger.warning(f'Error adding request to the queue: {exc}') - unprocessed_requests.append( - UnprocessedRequest( - unique_key=request.unique_key, - url=request.url, - method=request.method, - ) - ) + # If forefront is true, move the request to the beginning of the queue + if forefront: + # First remove the request from its current position + for idx, rec in enumerate(self._records): + if rec.id == request.id: + self._records.pop(idx) + break - return BatchRequestsOperationResponse( - processed_requests=processed_requests, - unprocessed_requests=unprocessed_requests, + # Then insert it at the beginning + self._records.insert(0, request) + + # Update metadata timestamps + await self._update_metadata(update_modified_at=True) + + return ProcessedRequest( + id=request.id, + unique_key=request.unique_key, + was_already_present=True, + was_already_handled=False, ) @override - async def batch_delete_requests(self, requests: list[Request]) -> BatchRequestsOperationResponse: - raise NotImplementedError('This method is not supported in memory storage.') - - async def update_timestamps(self, *, has_been_modified: bool) -> None: - """Update the timestamps of the request queue.""" - self._accessed_at = datetime.now(timezone.utc) + async def is_empty(self) -> bool: + """Check if the queue is empty. - if has_been_modified: - self._modified_at = datetime.now(timezone.utc) + Returns: + True if the queue is empty, False otherwise. + """ + await self._update_metadata(update_accessed_at=True) - await persist_metadata_if_enabled( - data=self.resource_info.model_dump(), - entity_directory=self.resource_directory, - write_metadata=self._memory_storage_client.write_metadata, - ) + # Queue is empty if there are no pending requests + pending_requests = [r for r in self._records if r.handled_at is None] + return len(pending_requests) == 0 - async def _persist_single_request_to_storage( + async def _update_metadata( self, *, - request: InternalRequest, - entity_directory: str, - persist_storage: bool, + update_accessed_at: bool = False, + update_modified_at: bool = False, ) -> None: - """Update or writes a single request item to the disk. - - This function writes a given request dictionary to a JSON file, named after the request's ID, - within a specified directory. The writing process is skipped if `persist_storage` is False. - Before writing, it ensures that the target directory exists, creating it if necessary. + """Update the request queue metadata with current information. Args: - request: The dictionary containing the request data. - entity_directory: The directory path where the request file should be stored. 
- persist_storage: A boolean flag indicating whether the request should be persisted to the disk. + update_accessed_at: If True, update the `accessed_at` timestamp to the current time. + update_modified_at: If True, update the `modified_at` timestamp to the current time. """ - # Skip writing files to the disk if the client has the option set to false - if not persist_storage: - return - - # Ensure the directory for the entity exists - await asyncio.to_thread(os.makedirs, entity_directory, exist_ok=True) - - # Write the request to the file - file_path = os.path.join(entity_directory, f'{request.id}.json') - f = await asyncio.to_thread(open, file_path, mode='w', encoding='utf-8') - try: - s = await json_dumps(request.model_dump()) - await asyncio.to_thread(f.write, s) - finally: - f.close() - - async def _delete_request_file_from_storage(self, *, request_id: str, entity_directory: str) -> None: - """Delete a specific request item from the disk. - - This function removes a file representing a request, identified by the request's ID, from a - specified directory. Before attempting to remove the file, it ensures that the target directory - exists, creating it if necessary. - - Args: - request_id: The identifier of the request to be deleted. - entity_directory: The directory path where the request file is stored. - """ - # Ensure the directory for the entity exists - await asyncio.to_thread(os.makedirs, entity_directory, exist_ok=True) - - file_path = os.path.join(entity_directory, f'{request_id}.json') - await force_remove(file_path) - - async def _create_internal_request(self, request: Request, forefront: bool | None) -> InternalRequest: - order_no = self._calculate_order_no(request, forefront) - id = unique_key_to_request_id(request.unique_key) - - if request.id is not None and request.id != id: - logger.warning( - f'The request ID does not match the ID from the unique_key (request.id={request.id}, id={id}).' 
- ) - - return InternalRequest.from_request(request=request, id=id, order_no=order_no) - - def _calculate_order_no(self, request: Request, forefront: bool | None) -> Decimal | None: - if request.handled_at is not None: - return None - - # Get the current timestamp in milliseconds - timestamp = Decimal(str(datetime.now(tz=timezone.utc).timestamp())) * Decimal(1000) - timestamp = round(timestamp, 6) - - # Make sure that this timestamp was not used yet, so that we have unique order_nos - if timestamp <= self._last_used_timestamp: - timestamp = self._last_used_timestamp + Decimal('0.000001') - - self._last_used_timestamp = timestamp + now = datetime.now(timezone.utc) - return -timestamp if forefront else timestamp + if update_accessed_at: + self._metadata.accessed_at = now + if update_modified_at: + self._metadata.modified_at = now diff --git a/src/crawlee/storage_clients/_memory/_request_queue_collection_client.py b/src/crawlee/storage_clients/_memory/_request_queue_collection_client.py deleted file mode 100644 index 2f2df2be89..0000000000 --- a/src/crawlee/storage_clients/_memory/_request_queue_collection_client.py +++ /dev/null @@ -1,62 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING - -from typing_extensions import override - -from crawlee.storage_clients._base import RequestQueueCollectionClient as BaseRequestQueueCollectionClient -from crawlee.storage_clients.models import RequestQueueListPage, RequestQueueMetadata - -from ._creation_management import get_or_create_inner -from ._request_queue_client import RequestQueueClient - -if TYPE_CHECKING: - from ._memory_storage_client import MemoryStorageClient - - -class RequestQueueCollectionClient(BaseRequestQueueCollectionClient): - """Subclient for manipulating request queues.""" - - def __init__(self, *, memory_storage_client: MemoryStorageClient) -> None: - self._memory_storage_client = memory_storage_client - - @property - def _storage_client_cache(self) -> list[RequestQueueClient]: - return self._memory_storage_client.request_queues_handled - - @override - async def get_or_create( - self, - *, - name: str | None = None, - schema: dict | None = None, - id: str | None = None, - ) -> RequestQueueMetadata: - resource_client = await get_or_create_inner( - memory_storage_client=self._memory_storage_client, - storage_client_cache=self._storage_client_cache, - resource_client_class=RequestQueueClient, - name=name, - id=id, - ) - return resource_client.resource_info - - @override - async def list( - self, - *, - unnamed: bool = False, - limit: int | None = None, - offset: int | None = None, - desc: bool = False, - ) -> RequestQueueListPage: - items = [storage.resource_info for storage in self._storage_client_cache] - - return RequestQueueListPage( - total=len(items), - count=len(items), - offset=0, - limit=len(items), - desc=False, - items=sorted(items, key=lambda item: item.created_at), - ) diff --git a/src/crawlee/storage_clients/_memory/_storage_client.py b/src/crawlee/storage_clients/_memory/_storage_client.py new file mode 100644 index 0000000000..6123a6ca53 --- /dev/null +++ b/src/crawlee/storage_clients/_memory/_storage_client.py @@ -0,0 +1,65 @@ +from __future__ import annotations + +from typing_extensions import override + +from crawlee.configuration import Configuration +from crawlee.storage_clients._base import StorageClient + +from ._dataset_client import MemoryDatasetClient +from ._key_value_store_client import MemoryKeyValueStoreClient +from ._request_queue_client import MemoryRequestQueueClient + + 
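The in-memory request queue client rewritten above replaces the old lock-and-file bookkeeping with a plain in-progress map, and the `MemoryStorageClient` introduced below simply drops and reopens each client when `purge_on_start` is set, giving every run a clean slate. A minimal consumer-side sketch of the resulting fetch, then handle-or-reclaim loop (the storage name and the processing step are illustrative, not part of this patch):

```python
import asyncio

from crawlee.storage_clients import MemoryStorageClient


async def main() -> None:
    # Open the rewritten in-memory request queue client via the new storage client.
    storage_client = MemoryStorageClient()
    rq_client = await storage_client.open_request_queue_client(name='demo')

    # Requests are assumed to have been added elsewhere (e.g. through the
    # higher-level RequestQueue storage); here we only drain the queue.
    while not await rq_client.is_empty():
        request = await rq_client.fetch_next_request()
        if request is None:
            break
        try:
            ...  # Process the request here.
        except Exception:
            # Return the request to the queue; forefront=True retries it first.
            await rq_client.reclaim_request(request, forefront=True)
        else:
            await rq_client.mark_request_as_handled(request)


asyncio.run(main())
```

Reclaimed requests go back into the records list (optionally at the front), while handled ones get a `handled_at` timestamp, leave the in-progress map, and adjust the handled/pending counts in the metadata.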
+class MemoryStorageClient(StorageClient): + """Memory storage client.""" + + @override + async def open_dataset_client( + self, + *, + id: str | None = None, + name: str | None = None, + configuration: Configuration | None = None, + ) -> MemoryDatasetClient: + configuration = configuration or Configuration.get_global_configuration() + client = await MemoryDatasetClient.open(id=id, name=name, configuration=configuration) + + if configuration.purge_on_start: + await client.drop() + client = await MemoryDatasetClient.open(id=id, name=name, configuration=configuration) + + return client + + @override + async def open_key_value_store_client( + self, + *, + id: str | None = None, + name: str | None = None, + configuration: Configuration | None = None, + ) -> MemoryKeyValueStoreClient: + configuration = configuration or Configuration.get_global_configuration() + client = await MemoryKeyValueStoreClient.open(id=id, name=name, configuration=configuration) + + if configuration.purge_on_start: + await client.drop() + client = await MemoryKeyValueStoreClient.open(id=id, name=name, configuration=configuration) + + return client + + @override + async def open_request_queue_client( + self, + *, + id: str | None = None, + name: str | None = None, + configuration: Configuration | None = None, + ) -> MemoryRequestQueueClient: + configuration = configuration or Configuration.get_global_configuration() + client = await MemoryRequestQueueClient.open(id=id, name=name, configuration=configuration) + + if configuration.purge_on_start: + await client.drop() + client = await MemoryRequestQueueClient.open(id=id, name=name, configuration=configuration) + + return client diff --git a/src/crawlee/storage_clients/models.py b/src/crawlee/storage_clients/models.py index f016e24730..f680ba945f 100644 --- a/src/crawlee/storage_clients/models.py +++ b/src/crawlee/storage_clients/models.py @@ -1,7 +1,7 @@ from __future__ import annotations import json -from datetime import datetime +from datetime import datetime, timedelta from decimal import Decimal from typing import Annotated, Any, Generic @@ -26,10 +26,19 @@ class StorageMetadata(BaseModel): model_config = ConfigDict(populate_by_name=True, extra='allow') id: Annotated[str, Field(alias='id')] - name: Annotated[str | None, Field(alias='name', default='')] + """The unique identifier of the storage.""" + + name: Annotated[str, Field(alias='name', default='default')] + """The name of the storage.""" + accessed_at: Annotated[datetime, Field(alias='accessedAt')] + """The timestamp when the storage was last accessed.""" + created_at: Annotated[datetime, Field(alias='createdAt')] + """The timestamp when the storage was created.""" + modified_at: Annotated[datetime, Field(alias='modifiedAt')] + """The timestamp when the storage was last modified.""" @docs_group('Data structures') @@ -39,6 +48,7 @@ class DatasetMetadata(StorageMetadata): model_config = ConfigDict(populate_by_name=True) item_count: Annotated[int, Field(alias='itemCount')] + """The number of items in the dataset.""" @docs_group('Data structures') @@ -47,8 +57,6 @@ class KeyValueStoreMetadata(StorageMetadata): model_config = ConfigDict(populate_by_name=True) - user_id: Annotated[str, Field(alias='userId')] - @docs_group('Data structures') class RequestQueueMetadata(StorageMetadata): @@ -57,44 +65,51 @@ class RequestQueueMetadata(StorageMetadata): model_config = ConfigDict(populate_by_name=True) had_multiple_clients: Annotated[bool, Field(alias='hadMultipleClients')] + """Indicates whether the queue has been accessed 
by multiple clients (consumers).""" + handled_request_count: Annotated[int, Field(alias='handledRequestCount')] + """The number of requests that have been handled from the queue.""" + pending_request_count: Annotated[int, Field(alias='pendingRequestCount')] + """The number of requests that are still pending in the queue.""" + stats: Annotated[dict, Field(alias='stats')] + """Statistics about the request queue, TODO?""" + total_request_count: Annotated[int, Field(alias='totalRequestCount')] - user_id: Annotated[str, Field(alias='userId')] - resource_directory: Annotated[str, Field(alias='resourceDirectory')] + """The total number of requests that have been added to the queue.""" @docs_group('Data structures') -class KeyValueStoreRecord(BaseModel, Generic[KvsValueType]): - """Model for a key-value store record.""" +class KeyValueStoreRecordMetadata(BaseModel): + """Model for a key-value store record metadata.""" model_config = ConfigDict(populate_by_name=True) key: Annotated[str, Field(alias='key')] - value: Annotated[KvsValueType, Field(alias='value')] - content_type: Annotated[str | None, Field(alias='contentType', default=None)] - filename: Annotated[str | None, Field(alias='filename', default=None)] + """The key of the record. + A unique identifier for the record in the key-value store. + """ -@docs_group('Data structures') -class KeyValueStoreRecordMetadata(BaseModel): - """Model for a key-value store record metadata.""" + content_type: Annotated[str, Field(alias='contentType')] + """The MIME type of the record. - model_config = ConfigDict(populate_by_name=True) + Describe the format and type of data stored in the record, following the MIME specification. + """ - key: Annotated[str, Field(alias='key')] - content_type: Annotated[str, Field(alias='contentType')] + size: Annotated[int, Field(alias='size')] + """The size of the record in bytes.""" @docs_group('Data structures') -class KeyValueStoreKeyInfo(BaseModel): - """Model for a key-value store key info.""" +class KeyValueStoreRecord(KeyValueStoreRecordMetadata, Generic[KvsValueType]): + """Model for a key-value store record.""" model_config = ConfigDict(populate_by_name=True) - key: Annotated[str, Field(alias='key')] - size: Annotated[int, Field(alias='size')] + value: Annotated[KvsValueType, Field(alias='value')] + """The value of the record.""" @docs_group('Data structures') @@ -104,11 +119,22 @@ class KeyValueStoreListKeysPage(BaseModel): model_config = ConfigDict(populate_by_name=True) count: Annotated[int, Field(alias='count')] + """The number of keys returned on this page.""" + limit: Annotated[int, Field(alias='limit')] + """The maximum number of keys to return.""" + is_truncated: Annotated[bool, Field(alias='isTruncated')] - items: Annotated[list[KeyValueStoreKeyInfo], Field(alias='items', default_factory=list)] + """Indicates whether there are more keys to retrieve.""" + exclusive_start_key: Annotated[str | None, Field(alias='exclusiveStartKey', default=None)] + """The key from which to start this page of results.""" + next_exclusive_start_key: Annotated[str | None, Field(alias='nextExclusiveStartKey', default=None)] + """The key from which to start the next page of results.""" + + items: Annotated[list[KeyValueStoreRecordMetadata], Field(alias='items', default_factory=list)] + """The list of KVS items metadata returned on this page.""" @docs_group('Data structures') @@ -126,22 +152,31 @@ class RequestQueueHeadState(BaseModel): @docs_group('Data structures') class RequestQueueHead(BaseModel): - """Model for the request queue 
head.""" + """Model for request queue head. + + Represents a collection of requests retrieved from the beginning of a queue, + including metadata about the queue's state and lock information for the requests. + """ model_config = ConfigDict(populate_by_name=True) limit: Annotated[int | None, Field(alias='limit', default=None)] - had_multiple_clients: Annotated[bool, Field(alias='hadMultipleClients')] + """The maximum number of requests that were requested from the queue.""" + + had_multiple_clients: Annotated[bool, Field(alias='hadMultipleClients', default=False)] + """Indicates whether the queue has been accessed by multiple clients (consumers).""" + queue_modified_at: Annotated[datetime, Field(alias='queueModifiedAt')] - items: Annotated[list[Request], Field(alias='items', default_factory=list)] + """The timestamp when the queue was last modified.""" + lock_time: Annotated[timedelta | None, Field(alias='lockSecs', default=None)] + """The duration for which the returned requests are locked and cannot be processed by other clients.""" -@docs_group('Data structures') -class RequestQueueHeadWithLocks(RequestQueueHead): - """Model for request queue head with locks.""" + queue_has_locked_requests: Annotated[bool | None, Field(alias='queueHasLockedRequests', default=False)] + """Indicates whether the queue contains any locked requests.""" - lock_secs: Annotated[int, Field(alias='lockSecs')] - queue_has_locked_requests: Annotated[bool | None, Field(alias='queueHasLockedRequests')] = None + items: Annotated[list[Request], Field(alias='items', default_factory=list[Request])] + """The list of request objects retrieved from the beginning of the queue.""" class _ListPage(BaseModel): @@ -230,13 +265,22 @@ class UnprocessedRequest(BaseModel): @docs_group('Data structures') -class BatchRequestsOperationResponse(BaseModel): - """Response to batch request deletion calls.""" +class AddRequestsResponse(BaseModel): + """Model for a response to add requests to a queue. + + Contains detailed information about the processing results when adding multiple requests + to a queue. This includes which requests were successfully processed and which ones + encountered issues during processing. 
+ """ model_config = ConfigDict(populate_by_name=True) processed_requests: Annotated[list[ProcessedRequest], Field(alias='processedRequests')] + """Successfully processed requests, including information about whether they were + already present in the queue and whether they had been handled previously.""" + unprocessed_requests: Annotated[list[UnprocessedRequest], Field(alias='unprocessedRequests')] + """Requests that could not be processed, typically due to validation errors or other issues.""" class InternalRequest(BaseModel): @@ -275,3 +319,22 @@ def from_request(cls, request: Request, id: str, order_no: Decimal | None) -> In def to_request(self) -> Request: """Convert the internal request back to a `Request` object.""" return self.request + + +class CachedRequest(BaseModel): + """Pydantic model for cached request information.""" + + id: str + """The ID of the request.""" + + was_already_handled: bool + """Whether the request was already handled.""" + + hydrated: Request | None = None + """The hydrated request object (the original one).""" + + lock_expires_at: datetime | None = None + """The expiration time of the lock on the request.""" + + forefront: bool = False + """Whether the request was added to the forefront of the queue.""" diff --git a/src/crawlee/storages/_base.py b/src/crawlee/storages/_base.py index 08d2cbd7be..8e73326041 100644 --- a/src/crawlee/storages/_base.py +++ b/src/crawlee/storages/_base.py @@ -6,7 +6,7 @@ if TYPE_CHECKING: from crawlee.configuration import Configuration from crawlee.storage_clients._base import StorageClient - from crawlee.storage_clients.models import StorageMetadata + from crawlee.storage_clients.models import DatasetMetadata, KeyValueStoreMetadata, RequestQueueMetadata class Storage(ABC): @@ -24,13 +24,8 @@ def name(self) -> str | None: @property @abstractmethod - def storage_object(self) -> StorageMetadata: - """Get the full storage object.""" - - @storage_object.setter - @abstractmethod - def storage_object(self, storage_object: StorageMetadata) -> None: - """Set the full storage object.""" + def metadata(self) -> DatasetMetadata | KeyValueStoreMetadata | RequestQueueMetadata: + """Get the storage metadata.""" @classmethod @abstractmethod diff --git a/src/crawlee/storages/_creation_management.py b/src/crawlee/storages/_creation_management.py deleted file mode 100644 index d7356a98b5..0000000000 --- a/src/crawlee/storages/_creation_management.py +++ /dev/null @@ -1,226 +0,0 @@ -from __future__ import annotations - -import asyncio -from typing import TYPE_CHECKING, TypeVar - -from crawlee.storage_clients import MemoryStorageClient - -from ._dataset import Dataset -from ._key_value_store import KeyValueStore -from ._request_queue import RequestQueue - -if TYPE_CHECKING: - from crawlee.configuration import Configuration - from crawlee.storage_clients._base import ResourceClient, ResourceCollectionClient, StorageClient - -TResource = TypeVar('TResource', Dataset, KeyValueStore, RequestQueue) - - -_creation_lock = asyncio.Lock() -"""Lock for storage creation.""" - -_cache_dataset_by_id: dict[str, Dataset] = {} -_cache_dataset_by_name: dict[str, Dataset] = {} -_cache_kvs_by_id: dict[str, KeyValueStore] = {} -_cache_kvs_by_name: dict[str, KeyValueStore] = {} -_cache_rq_by_id: dict[str, RequestQueue] = {} -_cache_rq_by_name: dict[str, RequestQueue] = {} - - -def _get_from_cache_by_name( - storage_class: type[TResource], - name: str, -) -> TResource | None: - """Try to restore storage from cache by name.""" - if issubclass(storage_class, Dataset): - 
return _cache_dataset_by_name.get(name) - if issubclass(storage_class, KeyValueStore): - return _cache_kvs_by_name.get(name) - if issubclass(storage_class, RequestQueue): - return _cache_rq_by_name.get(name) - raise ValueError(f'Unknown storage class: {storage_class.__name__}') - - -def _get_from_cache_by_id( - storage_class: type[TResource], - id: str, -) -> TResource | None: - """Try to restore storage from cache by ID.""" - if issubclass(storage_class, Dataset): - return _cache_dataset_by_id.get(id) - if issubclass(storage_class, KeyValueStore): - return _cache_kvs_by_id.get(id) - if issubclass(storage_class, RequestQueue): - return _cache_rq_by_id.get(id) - raise ValueError(f'Unknown storage: {storage_class.__name__}') - - -def _add_to_cache_by_name(name: str, storage: TResource) -> None: - """Add storage to cache by name.""" - if isinstance(storage, Dataset): - _cache_dataset_by_name[name] = storage - elif isinstance(storage, KeyValueStore): - _cache_kvs_by_name[name] = storage - elif isinstance(storage, RequestQueue): - _cache_rq_by_name[name] = storage - else: - raise TypeError(f'Unknown storage: {storage}') - - -def _add_to_cache_by_id(id: str, storage: TResource) -> None: - """Add storage to cache by ID.""" - if isinstance(storage, Dataset): - _cache_dataset_by_id[id] = storage - elif isinstance(storage, KeyValueStore): - _cache_kvs_by_id[id] = storage - elif isinstance(storage, RequestQueue): - _cache_rq_by_id[id] = storage - else: - raise TypeError(f'Unknown storage: {storage}') - - -def _rm_from_cache_by_id(storage_class: type, id: str) -> None: - """Remove a storage from cache by ID.""" - try: - if issubclass(storage_class, Dataset): - del _cache_dataset_by_id[id] - elif issubclass(storage_class, KeyValueStore): - del _cache_kvs_by_id[id] - elif issubclass(storage_class, RequestQueue): - del _cache_rq_by_id[id] - else: - raise TypeError(f'Unknown storage class: {storage_class.__name__}') - except KeyError as exc: - raise RuntimeError(f'Storage with provided ID was not found ({id}).') from exc - - -def _rm_from_cache_by_name(storage_class: type, name: str) -> None: - """Remove a storage from cache by name.""" - try: - if issubclass(storage_class, Dataset): - del _cache_dataset_by_name[name] - elif issubclass(storage_class, KeyValueStore): - del _cache_kvs_by_name[name] - elif issubclass(storage_class, RequestQueue): - del _cache_rq_by_name[name] - else: - raise TypeError(f'Unknown storage class: {storage_class.__name__}') - except KeyError as exc: - raise RuntimeError(f'Storage with provided name was not found ({name}).') from exc - - -def _get_default_storage_id(configuration: Configuration, storage_class: type[TResource]) -> str: - if issubclass(storage_class, Dataset): - return configuration.default_dataset_id - if issubclass(storage_class, KeyValueStore): - return configuration.default_key_value_store_id - if issubclass(storage_class, RequestQueue): - return configuration.default_request_queue_id - - raise TypeError(f'Unknown storage class: {storage_class.__name__}') - - -async def open_storage( - *, - storage_class: type[TResource], - id: str | None, - name: str | None, - configuration: Configuration, - storage_client: StorageClient, -) -> TResource: - """Open either a new storage or restore an existing one and return it.""" - # Try to restore the storage from cache by name - if name: - cached_storage = _get_from_cache_by_name(storage_class=storage_class, name=name) - if cached_storage: - return cached_storage - - default_id = _get_default_storage_id(configuration, 
storage_class) - - if not id and not name: - id = default_id - - # Find out if the storage is a default on memory storage - is_default_on_memory = id == default_id and isinstance(storage_client, MemoryStorageClient) - - # Try to restore storage from cache by ID - if id: - cached_storage = _get_from_cache_by_id(storage_class=storage_class, id=id) - if cached_storage: - return cached_storage - - # Purge on start if configured - if configuration.purge_on_start: - await storage_client.purge_on_start() - - # Lock and create new storage - async with _creation_lock: - if id and not is_default_on_memory: - resource_client = _get_resource_client(storage_class, storage_client, id) - storage_object = await resource_client.get() - if not storage_object: - raise RuntimeError(f'{storage_class.__name__} with id "{id}" does not exist!') - - elif is_default_on_memory: - resource_collection_client = _get_resource_collection_client(storage_class, storage_client) - storage_object = await resource_collection_client.get_or_create(name=name, id=id) - - else: - resource_collection_client = _get_resource_collection_client(storage_class, storage_client) - storage_object = await resource_collection_client.get_or_create(name=name) - - storage = storage_class.from_storage_object(storage_client=storage_client, storage_object=storage_object) - - # Cache the storage by ID and name - _add_to_cache_by_id(storage.id, storage) - if storage.name is not None: - _add_to_cache_by_name(storage.name, storage) - - return storage - - -def remove_storage_from_cache( - *, - storage_class: type, - id: str | None = None, - name: str | None = None, -) -> None: - """Remove a storage from cache by ID or name.""" - if id: - _rm_from_cache_by_id(storage_class=storage_class, id=id) - - if name: - _rm_from_cache_by_name(storage_class=storage_class, name=name) - - -def _get_resource_client( - storage_class: type[TResource], - storage_client: StorageClient, - id: str, -) -> ResourceClient: - if issubclass(storage_class, Dataset): - return storage_client.dataset(id) - - if issubclass(storage_class, KeyValueStore): - return storage_client.key_value_store(id) - - if issubclass(storage_class, RequestQueue): - return storage_client.request_queue(id) - - raise ValueError(f'Unknown storage class label: {storage_class.__name__}') - - -def _get_resource_collection_client( - storage_class: type, - storage_client: StorageClient, -) -> ResourceCollectionClient: - if issubclass(storage_class, Dataset): - return storage_client.datasets() - - if issubclass(storage_class, KeyValueStore): - return storage_client.key_value_stores() - - if issubclass(storage_class, RequestQueue): - return storage_client.request_queues() - - raise ValueError(f'Unknown storage class: {storage_class.__name__}') diff --git a/src/crawlee/storages/_dataset.py b/src/crawlee/storages/_dataset.py index 7cb58ae817..413290faab 100644 --- a/src/crawlee/storages/_dataset.py +++ b/src/crawlee/storages/_dataset.py @@ -1,243 +1,101 @@ from __future__ import annotations -import csv -import io -import json import logging -from datetime import datetime, timezone -from typing import TYPE_CHECKING, Literal, TextIO, TypedDict, cast +from io import StringIO +from typing import TYPE_CHECKING, overload -from typing_extensions import NotRequired, Required, Unpack, override +from typing_extensions import override from crawlee import service_locator -from crawlee._utils.byte_size import ByteSize from crawlee._utils.docs import docs_group -from crawlee._utils.file import json_dumps -from 
crawlee.storage_clients.models import DatasetMetadata, StorageMetadata +from crawlee._utils.file import export_csv_to_stream, export_json_to_stream from ._base import Storage from ._key_value_store import KeyValueStore if TYPE_CHECKING: - from collections.abc import AsyncIterator, Callable + from collections.abc import AsyncIterator + from typing import Any, ClassVar, Literal + + from typing_extensions import Unpack - from crawlee._types import JsonSerializable, PushDataKwargs from crawlee.configuration import Configuration from crawlee.storage_clients import StorageClient - from crawlee.storage_clients.models import DatasetItemsListPage - -logger = logging.getLogger(__name__) - - -class GetDataKwargs(TypedDict): - """Keyword arguments for dataset's `get_data` method.""" - - offset: NotRequired[int] - """Skip the specified number of items at the start.""" - - limit: NotRequired[int] - """The maximum number of items to retrieve. Unlimited if None.""" - - clean: NotRequired[bool] - """Return only non-empty items and excludes hidden fields. Shortcut for skip_hidden and skip_empty.""" - - desc: NotRequired[bool] - """Set to True to sort results in descending order.""" - - fields: NotRequired[list[str]] - """Fields to include in each item. Sorts fields as specified if provided.""" - - omit: NotRequired[list[str]] - """Fields to exclude from each item.""" - - unwind: NotRequired[str] - """Unwind items by a specified array field, turning each element into a separate item.""" - - skip_empty: NotRequired[bool] - """Exclude empty items from the results if True.""" - - skip_hidden: NotRequired[bool] - """Exclude fields starting with '#' if True.""" - - flatten: NotRequired[list[str]] - """Field to be flattened in returned items.""" - - view: NotRequired[str] - """Specify the dataset view to be used.""" - - -class ExportToKwargs(TypedDict): - """Keyword arguments for dataset's `export_to` method.""" - - key: Required[str] - """The key under which to save the data.""" - - content_type: NotRequired[Literal['json', 'csv']] - """The format in which to export the data. Either 'json' or 'csv'.""" - - to_key_value_store_id: NotRequired[str] - """ID of the key-value store to save the exported file.""" - - to_key_value_store_name: NotRequired[str] - """Name of the key-value store to save the exported file.""" - - -class ExportDataJsonKwargs(TypedDict): - """Keyword arguments for dataset's `export_data_json` method.""" - - skipkeys: NotRequired[bool] - """If True (default: False), dict keys that are not of a basic type (str, int, float, bool, None) will be skipped - instead of raising a `TypeError`.""" - - ensure_ascii: NotRequired[bool] - """Determines if non-ASCII characters should be escaped in the output JSON string.""" - - check_circular: NotRequired[bool] - """If False (default: True), skips the circular reference check for container types. A circular reference will - result in a `RecursionError` or worse if unchecked.""" - - allow_nan: NotRequired[bool] - """If False (default: True), raises a ValueError for out-of-range float values (nan, inf, -inf) to strictly comply - with the JSON specification. If True, uses their JavaScript equivalents (NaN, Infinity, -Infinity).""" - - cls: NotRequired[type[json.JSONEncoder]] - """Allows specifying a custom JSON encoder.""" - - indent: NotRequired[int] - """Specifies the number of spaces for indentation in the pretty-printed JSON output.""" - - separators: NotRequired[tuple[str, str]] - """A tuple of (item_separator, key_separator). 
The default is (', ', ': ') if indent is None and (',', ': ') - otherwise.""" - - default: NotRequired[Callable] - """A function called for objects that can't be serialized otherwise. It should return a JSON-encodable version - of the object or raise a `TypeError`.""" - - sort_keys: NotRequired[bool] - """Specifies whether the output JSON object should have keys sorted alphabetically.""" - - -class ExportDataCsvKwargs(TypedDict): - """Keyword arguments for dataset's `export_data_csv` method.""" - - dialect: NotRequired[str] - """Specifies a dialect to be used in CSV parsing and writing.""" - - delimiter: NotRequired[str] - """A one-character string used to separate fields. Defaults to ','.""" - - doublequote: NotRequired[bool] - """Controls how instances of `quotechar` inside a field should be quoted. When True, the character is doubled; - when False, the `escapechar` is used as a prefix. Defaults to True.""" - - escapechar: NotRequired[str] - """A one-character string used to escape the delimiter if `quoting` is set to `QUOTE_NONE` and the `quotechar` - if `doublequote` is False. Defaults to None, disabling escaping.""" - - lineterminator: NotRequired[str] - """The string used to terminate lines produced by the writer. Defaults to '\\r\\n'.""" + from crawlee.storage_clients._base import DatasetClient + from crawlee.storage_clients.models import DatasetItemsListPage, DatasetMetadata - quotechar: NotRequired[str] - """A one-character string used to quote fields containing special characters, like the delimiter or quotechar, - or fields containing new-line characters. Defaults to '\"'.""" + from ._types import ExportDataCsvKwargs, ExportDataJsonKwargs - quoting: NotRequired[int] - """Controls when quotes should be generated by the writer and recognized by the reader. Can take any of - the `QUOTE_*` constants, with a default of `QUOTE_MINIMAL`.""" - - skipinitialspace: NotRequired[bool] - """When True, spaces immediately following the delimiter are ignored. Defaults to False.""" - - strict: NotRequired[bool] - """When True, raises an exception on bad CSV input. Defaults to False.""" +logger = logging.getLogger(__name__) @docs_group('Classes') class Dataset(Storage): - """Represents an append-only structured storage, ideal for tabular data similar to database tables. + """Dataset is a storage for managing structured tabular data. - The `Dataset` class is designed to store structured data, where each entry (row) maintains consistent attributes - (columns) across the dataset. It operates in an append-only mode, allowing new records to be added, but not - modified or deleted. This makes it particularly useful for storing results from web crawling operations. + The dataset class provides a high-level interface for storing and retrieving structured data + with consistent schema, similar to database tables or spreadsheets. It abstracts the underlying + storage implementation details, offering a consistent API regardless of where the data is + physically stored. - Data can be stored either locally or in the cloud. It depends on the setup of underlying storage client. - By default a `MemoryStorageClient` is used, but it can be changed to a different one. - - By default, data is stored using the following path structure: - ``` - {CRAWLEE_STORAGE_DIR}/datasets/{DATASET_ID}/{INDEX}.json - ``` - - `{CRAWLEE_STORAGE_DIR}`: The root directory for all storage data specified by the environment variable. - - `{DATASET_ID}`: Specifies the dataset, either "default" or a custom dataset ID. 
- - `{INDEX}`: Represents the zero-based index of the record within the dataset. + Dataset operates in an append-only mode, allowing new records to be added but not modified + or deleted after creation. This makes it particularly suitable for storing crawling results + and other data that should be immutable once collected. - To open a dataset, use the `open` class method by specifying an `id`, `name`, or `configuration`. If none are - provided, the default dataset for the current crawler run is used. Attempting to open a dataset by `id` that does - not exist will raise an error; however, if accessed by `name`, the dataset will be created if it doesn't already - exist. + The class provides methods for adding data, retrieving data with various filtering options, + and exporting data to different formats. You can create a dataset using the `open` class method, + specifying either a name or ID. The underlying storage implementation is determined by + the configured storage client. ### Usage ```python from crawlee.storages import Dataset + # Open a dataset dataset = await Dataset.open(name='my_dataset') - ``` - """ - _MAX_PAYLOAD_SIZE = ByteSize.from_mb(9) - """Maximum size for a single payload.""" + # Add data + await dataset.push_data({'title': 'Example Product', 'price': 99.99}) - _SAFETY_BUFFER_PERCENT = 0.01 / 100 # 0.01% - """Percentage buffer to reduce payload limit slightly for safety.""" + # Retrieve filtered data + results = await dataset.get_data(limit=10, desc=True) + + # Export data + await dataset.export_to('results.json', content_type='json') + ``` + """ - _EFFECTIVE_LIMIT_SIZE = _MAX_PAYLOAD_SIZE - (_MAX_PAYLOAD_SIZE * _SAFETY_BUFFER_PERCENT) - """Calculated payload limit considering safety buffer.""" + _cache_by_id: ClassVar[dict[str, Dataset]] = {} + """A dictionary to cache datasets by their IDs.""" - def __init__(self, id: str, name: str | None, storage_client: StorageClient) -> None: - self._id = id - self._name = name - datetime_now = datetime.now(timezone.utc) - self._storage_object = StorageMetadata( - id=id, name=name, accessed_at=datetime_now, created_at=datetime_now, modified_at=datetime_now - ) + _cache_by_name: ClassVar[dict[str, Dataset]] = {} + """A dictionary to cache datasets by their names.""" - # Get resource clients from the storage client. - self._resource_client = storage_client.dataset(self._id) - self._resource_collection_client = storage_client.datasets() + def __init__(self, client: DatasetClient) -> None: + """Initialize a new instance. - @classmethod - def from_storage_object(cls, storage_client: StorageClient, storage_object: StorageMetadata) -> Dataset: - """Initialize a new instance of Dataset from a storage metadata object.""" - dataset = Dataset( - id=storage_object.id, - name=storage_object.name, - storage_client=storage_client, - ) + Preferably use the `Dataset.open` constructor to create a new instance. - dataset.storage_object = storage_object - return dataset + Args: + client: An instance of a dataset client. 
+ """ + self._client = client - @property @override + @property def id(self) -> str: - return self._id + return self._client.metadata.id - @property @override - def name(self) -> str | None: - return self._name - @property - @override - def storage_object(self) -> StorageMetadata: - return self._storage_object + def name(self) -> str | None: + return self._client.metadata.name - @storage_object.setter @override - def storage_object(self, storage_object: StorageMetadata) -> None: - self._storage_object = storage_object + @property + def metadata(self) -> DatasetMetadata: + return self._client.metadata @override @classmethod @@ -249,27 +107,45 @@ async def open( configuration: Configuration | None = None, storage_client: StorageClient | None = None, ) -> Dataset: - from crawlee.storages._creation_management import open_storage + if id and name: + raise ValueError('Only one of "id" or "name" can be specified, not both.') - configuration = configuration or service_locator.get_configuration() - storage_client = storage_client or service_locator.get_storage_client() + # Check if dataset is already cached by id or name + if id and id in cls._cache_by_id: + return cls._cache_by_id[id] + if name and name in cls._cache_by_name: + return cls._cache_by_name[name] - return await open_storage( - storage_class=cls, + configuration = service_locator.get_configuration() if configuration is None else configuration + storage_client = service_locator.get_storage_client() if storage_client is None else storage_client + + client = await storage_client.open_dataset_client( id=id, name=name, configuration=configuration, - storage_client=storage_client, ) + dataset = cls(client) + + # Cache the dataset by id and name if available + if dataset.id: + cls._cache_by_id[dataset.id] = dataset + if dataset.name: + cls._cache_by_name[dataset.name] = dataset + + return dataset + @override async def drop(self) -> None: - from crawlee.storages._creation_management import remove_storage_from_cache + # Remove from cache before dropping + if self.id in self._cache_by_id: + del self._cache_by_id[self.id] + if self.name and self.name in self._cache_by_name: + del self._cache_by_name[self.name] - await self._resource_client.delete() - remove_storage_from_cache(storage_class=self.__class__, id=self._id, name=self._name) + await self._client.drop() - async def push_data(self, data: JsonSerializable, **kwargs: Unpack[PushDataKwargs]) -> None: + async def push_data(self, data: list[Any] | dict[str, Any]) -> None: """Store an object or an array of objects to the dataset. The size of the data is limited by the receiving API and therefore `push_data()` will only @@ -279,127 +155,65 @@ async def push_data(self, data: JsonSerializable, **kwargs: Unpack[PushDataKwarg Args: data: A JSON serializable data structure to be stored in the dataset. The JSON representation of each item must be smaller than 9MB. - kwargs: Keyword arguments for the storage client method. 
""" - # Handle singular items - if not isinstance(data, list): - items = await self.check_and_serialize(data) - return await self._resource_client.push_items(items, **kwargs) + await self._client.push_data(data=data) - # Handle lists - payloads_generator = (await self.check_and_serialize(item, index) for index, item in enumerate(data)) - - # Invoke client in series to preserve the order of data - async for items in self._chunk_by_size(payloads_generator): - await self._resource_client.push_items(items, **kwargs) - - return None - - async def get_data(self, **kwargs: Unpack[GetDataKwargs]) -> DatasetItemsListPage: - """Retrieve dataset items based on filtering, sorting, and pagination parameters. + async def get_data( + self, + *, + offset: int = 0, + limit: int | None = 999_999_999_999, + clean: bool = False, + desc: bool = False, + fields: list[str] | None = None, + omit: list[str] | None = None, + unwind: str | None = None, + skip_empty: bool = False, + skip_hidden: bool = False, + flatten: list[str] | None = None, + view: str | None = None, + ) -> DatasetItemsListPage: + """Retrieve a paginated list of items from a dataset based on various filtering parameters. - This method allows customization of the data retrieval process from a dataset, supporting operations such as - field selection, ordering, and skipping specific records based on provided parameters. + This method provides the flexibility to filter, sort, and modify the appearance of dataset items + when listed. Each parameter modifies the result set according to its purpose. The method also + supports pagination through 'offset' and 'limit' parameters. Args: - kwargs: Keyword arguments for the storage client method. + offset: Skips the specified number of items at the start. + limit: The maximum number of items to retrieve. Unlimited if None. + clean: Return only non-empty items and excludes hidden fields. Shortcut for skip_hidden and skip_empty. + desc: Set to True to sort results in descending order. + fields: Fields to include in each item. Sorts fields as specified if provided. + omit: Fields to exclude from each item. + unwind: Unwinds items by a specified array field, turning each element into a separate item. + skip_empty: Excludes empty items from the results if True. + skip_hidden: Excludes fields starting with '#' if True. + flatten: Fields to be flattened in returned items. + view: Specifies the dataset view to be used. Returns: - List page containing filtered and paginated dataset items. - """ - return await self._resource_client.list_items(**kwargs) - - async def write_to_csv(self, destination: TextIO, **kwargs: Unpack[ExportDataCsvKwargs]) -> None: - """Export the entire dataset into an arbitrary stream. - - Args: - destination: The stream into which the dataset contents should be written. - kwargs: Additional keyword arguments for `csv.writer`. - """ - items: list[dict] = [] - limit = 1000 - offset = 0 - - while True: - list_items = await self._resource_client.list_items(limit=limit, offset=offset) - items.extend(list_items.items) - if list_items.total <= offset + list_items.count: - break - offset += list_items.count - - if items: - writer = csv.writer(destination, **kwargs) - writer.writerows([items[0].keys(), *[item.values() for item in items]]) - else: - logger.warning('Attempting to export an empty dataset - no file will be created') - - async def write_to_json(self, destination: TextIO, **kwargs: Unpack[ExportDataJsonKwargs]) -> None: - """Export the entire dataset into an arbitrary stream. 
- - Args: - destination: The stream into which the dataset contents should be written. - kwargs: Additional keyword arguments for `json.dump`. - """ - items: list[dict] = [] - limit = 1000 - offset = 0 - - while True: - list_items = await self._resource_client.list_items(limit=limit, offset=offset) - items.extend(list_items.items) - if list_items.total <= offset + list_items.count: - break - offset += list_items.count - - if items: - json.dump(items, destination, **kwargs) - else: - logger.warning('Attempting to export an empty dataset - no file will be created') - - async def export_to(self, **kwargs: Unpack[ExportToKwargs]) -> None: - """Export the entire dataset into a specified file stored under a key in a key-value store. - - This method consolidates all entries from a specified dataset into one file, which is then saved under a - given key in a key-value store. The format of the exported file is determined by the `content_type` parameter. - Either the dataset's ID or name should be specified, and similarly, either the target key-value store's ID or - name should be used. - - Args: - kwargs: Keyword arguments for the storage client method. + An object with filtered, sorted, and paginated dataset items plus pagination details. """ - key = cast('str', kwargs.get('key')) - content_type = kwargs.get('content_type', 'json') - to_key_value_store_id = kwargs.get('to_key_value_store_id') - to_key_value_store_name = kwargs.get('to_key_value_store_name') - - key_value_store = await KeyValueStore.open(id=to_key_value_store_id, name=to_key_value_store_name) - - output = io.StringIO() - if content_type == 'csv': - await self.write_to_csv(output) - elif content_type == 'json': - await self.write_to_json(output) - else: - raise ValueError('Unsupported content type, expecting CSV or JSON') - - if content_type == 'csv': - await key_value_store.set_value(key, output.getvalue(), 'text/csv') - - if content_type == 'json': - await key_value_store.set_value(key, output.getvalue(), 'application/json') - - async def get_info(self) -> DatasetMetadata | None: - """Get an object containing general information about the dataset.""" - metadata = await self._resource_client.get() - if isinstance(metadata, DatasetMetadata): - return metadata - return None + return await self._client.get_data( + offset=offset, + limit=limit, + clean=clean, + desc=desc, + fields=fields, + omit=omit, + unwind=unwind, + skip_empty=skip_empty, + skip_hidden=skip_hidden, + flatten=flatten, + view=view, + ) async def iterate_items( self, *, offset: int = 0, - limit: int | None = None, + limit: int | None = 999_999_999_999, clean: bool = False, desc: bool = False, fields: list[str] | None = None, @@ -408,27 +222,29 @@ async def iterate_items( skip_empty: bool = False, skip_hidden: bool = False, ) -> AsyncIterator[dict]: - """Iterate over dataset items, applying filtering, sorting, and pagination. + """Iterate over items in the dataset according to specified filters and sorting. - Retrieve dataset items incrementally, allowing fine-grained control over the data fetched. The function - supports various parameters to filter, sort, and limit the data returned, facilitating tailored dataset - queries. + This method allows for asynchronously iterating through dataset items while applying various filters such as + skipping empty items, hiding specific fields, and sorting. 
It supports pagination via `offset` and `limit` + parameters, and can modify the appearance of dataset items using `fields`, `omit`, `unwind`, `skip_empty`, and + `skip_hidden` parameters. Args: - offset: Initial number of items to skip. - limit: Max number of items to return. No limit if None. - clean: Filter out empty items and hidden fields if True. - desc: Return items in reverse order if True. - fields: Specific fields to include in each item. - omit: Fields to omit from each item. - unwind: Field name to unwind items by. - skip_empty: Omits empty items if True. + offset: Skips the specified number of items at the start. + limit: The maximum number of items to retrieve. Unlimited if None. + clean: Return only non-empty items and excludes hidden fields. Shortcut for skip_hidden and skip_empty. + desc: Set to True to sort results in descending order. + fields: Fields to include in each item. Sorts fields as specified if provided. + omit: Fields to exclude from each item. + unwind: Unwinds items by a specified array field, turning each element into a separate item. + skip_empty: Excludes empty items from the results if True. skip_hidden: Excludes fields starting with '#' if True. Yields: - Each item from the dataset as a dictionary. + An asynchronous iterator of dictionary objects, each representing a dataset item after applying + the specified filters and transformations. """ - async for item in self._resource_client.iterate_items( + async for item in self._client.iterate_items( offset=offset, limit=limit, clean=clean, @@ -441,59 +257,108 @@ async def iterate_items( ): yield item - @classmethod - async def check_and_serialize(cls, item: JsonSerializable, index: int | None = None) -> str: - """Serialize a given item to JSON, checks its serializability and size against a limit. + async def list_items( + self, + *, + offset: int = 0, + limit: int | None = 999_999_999_999, + clean: bool = False, + desc: bool = False, + fields: list[str] | None = None, + omit: list[str] | None = None, + unwind: str | None = None, + skip_empty: bool = False, + skip_hidden: bool = False, + ) -> list[dict]: + """Retrieve a list of all items from the dataset according to specified filters and sorting. + + This method collects all dataset items into a list while applying various filters such as + skipping empty items, hiding specific fields, and sorting. It supports pagination via `offset` and `limit` + parameters, and can modify the appearance of dataset items using `fields`, `omit`, `unwind`, `skip_empty`, and + `skip_hidden` parameters. Args: - item: The item to serialize. - index: Index of the item, used for error context. + offset: Skips the specified number of items at the start. + limit: The maximum number of items to retrieve. Unlimited if None. + clean: Return only non-empty items and excludes hidden fields. Shortcut for skip_hidden and skip_empty. + desc: Set to True to sort results in descending order. + fields: Fields to include in each item. Sorts fields as specified if provided. + omit: Fields to exclude from each item. + unwind: Unwinds items by a specified array field, turning each element into a separate item. + skip_empty: Excludes empty items from the results if True. + skip_hidden: Excludes fields starting with '#' if True. Returns: - Serialized JSON string. - - Raises: - ValueError: If item is not JSON serializable or exceeds size limit. + A list of dictionary objects, each representing a dataset item after applying + the specified filters and transformations. 
""" - s = ' ' if index is None else f' at index {index} ' - - try: - payload = await json_dumps(item) - except Exception as exc: - raise ValueError(f'Data item{s}is not serializable to JSON.') from exc - - payload_size = ByteSize(len(payload.encode('utf-8'))) - if payload_size > cls._EFFECTIVE_LIMIT_SIZE: - raise ValueError(f'Data item{s}is too large (size: {payload_size}, limit: {cls._EFFECTIVE_LIMIT_SIZE})') - - return payload - - async def _chunk_by_size(self, items: AsyncIterator[str]) -> AsyncIterator[str]: - """Yield chunks of JSON arrays composed of input strings, respecting a size limit. + return [ + item + async for item in self.iterate_items( + offset=offset, + limit=limit, + clean=clean, + desc=desc, + fields=fields, + omit=omit, + unwind=unwind, + skip_empty=skip_empty, + skip_hidden=skip_hidden, + ) + ] + + @overload + async def export_to( + self, + key: str, + content_type: Literal['json'], + to_key_value_store_id: str | None = None, + to_key_value_store_name: str | None = None, + **kwargs: Unpack[ExportDataJsonKwargs], + ) -> None: ... + + @overload + async def export_to( + self, + key: str, + content_type: Literal['csv'], + to_key_value_store_id: str | None = None, + to_key_value_store_name: str | None = None, + **kwargs: Unpack[ExportDataCsvKwargs], + ) -> None: ... + + async def export_to( + self, + key: str, + content_type: Literal['json', 'csv'] = 'json', + to_key_value_store_id: str | None = None, + to_key_value_store_name: str | None = None, + **kwargs: Any, + ) -> None: + """Export the entire dataset into a specified file stored under a key in a key-value store. - Groups an iterable of JSON string payloads into larger JSON arrays, ensuring the total size - of each array does not exceed `EFFECTIVE_LIMIT_SIZE`. Each output is a JSON array string that - contains as many payloads as possible without breaching the size threshold, maintaining the - order of the original payloads. Assumes individual items are below the size limit. + This method consolidates all entries from a specified dataset into one file, which is then saved under a + given key in a key-value store. The format of the exported file is determined by the `content_type` parameter. + Either the dataset's ID or name should be specified, and similarly, either the target key-value store's ID or + name should be used. Args: - items: Iterable of JSON string payloads. - - Yields: - Strings representing JSON arrays of payloads, each staying within the size limit. + key: The key under which to save the data in the key-value store. + content_type: The format in which to export the data. + to_key_value_store_id: ID of the key-value store to save the exported file. + Specify only one of ID or name. + to_key_value_store_name: Name of the key-value store to save the exported file. + Specify only one of ID or name. + kwargs: Additional parameters for the export operation, specific to the chosen content type. """ - last_chunk_size = ByteSize(2) # Add 2 bytes for [] wrapper. - current_chunk = [] - - async for payload in items: - payload_size = ByteSize(len(payload.encode('utf-8'))) - - if last_chunk_size + payload_size <= self._EFFECTIVE_LIMIT_SIZE: - current_chunk.append(payload) - last_chunk_size += payload_size + ByteSize(1) # Add 1 byte for ',' separator. - else: - yield f'[{",".join(current_chunk)}]' - current_chunk = [payload] - last_chunk_size = payload_size + ByteSize(2) # Add 2 bytes for [] wrapper. 
+ kvs = await KeyValueStore.open(id=to_key_value_store_id, name=to_key_value_store_name) + dst = StringIO() - yield f'[{",".join(current_chunk)}]' + if content_type == 'csv': + await export_csv_to_stream(self.iterate_items(), dst, **kwargs) + await kvs.set_value(key, dst.getvalue(), 'text/csv') + elif content_type == 'json': + await export_json_to_stream(self.iterate_items(), dst, **kwargs) + await kvs.set_value(key, dst.getvalue(), 'application/json') + else: + raise ValueError('Unsupported content type, expecting CSV or JSON') diff --git a/src/crawlee/storages/_key_value_store.py b/src/crawlee/storages/_key_value_store.py index b7d3a4b582..41f9afe37e 100644 --- a/src/crawlee/storages/_key_value_store.py +++ b/src/crawlee/storages/_key_value_store.py @@ -1,115 +1,86 @@ from __future__ import annotations -import asyncio -from collections.abc import AsyncIterator -from datetime import datetime, timezone from typing import TYPE_CHECKING, Any, ClassVar, TypeVar, overload from typing_extensions import override from crawlee import service_locator from crawlee._utils.docs import docs_group -from crawlee.events._types import Event, EventPersistStateData -from crawlee.storage_clients.models import KeyValueStoreKeyInfo, KeyValueStoreMetadata, StorageMetadata from ._base import Storage if TYPE_CHECKING: from collections.abc import AsyncIterator - from crawlee._types import JsonSerializable from crawlee.configuration import Configuration from crawlee.storage_clients import StorageClient + from crawlee.storage_clients._base import KeyValueStoreClient + from crawlee.storage_clients.models import KeyValueStoreMetadata, KeyValueStoreRecordMetadata T = TypeVar('T') @docs_group('Classes') class KeyValueStore(Storage): - """Represents a key-value based storage for reading and writing data records or files. + """Key-value store is a storage for reading and writing data records with unique key identifiers. - Each data record is identified by a unique key and associated with a specific MIME content type. This class is - commonly used in crawler runs to store inputs and outputs, typically in JSON format, but it also supports other - content types. + The key-value store class acts as a high-level interface for storing, retrieving, and managing data records + identified by unique string keys. It abstracts away the underlying storage implementation details, + allowing you to work with the same API regardless of whether data is stored in memory, on disk, + or in the cloud. - Data can be stored either locally or in the cloud. It depends on the setup of underlying storage client. - By default a `MemoryStorageClient` is used, but it can be changed to a different one. + Each data record is associated with a specific MIME content type, allowing storage of various + data formats such as JSON, text, images, HTML snapshots or any binary data. This class is + commonly used to store inputs, outputs, and other artifacts of crawler operations. - By default, data is stored using the following path structure: - ``` - {CRAWLEE_STORAGE_DIR}/key_value_stores/{STORE_ID}/{KEY}.{EXT} - ``` - - `{CRAWLEE_STORAGE_DIR}`: The root directory for all storage data specified by the environment variable. - - `{STORE_ID}`: The identifier for the key-value store, either "default" or as specified by - `CRAWLEE_DEFAULT_KEY_VALUE_STORE_ID`. - - `{KEY}`: The unique key for the record. - - `{EXT}`: The file extension corresponding to the MIME type of the content. 
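The reworked `export_to` above streams items through the shared `export_csv_to_stream` / `export_json_to_stream` helpers and saves the result as a single record in a key-value store. From the caller's side it looks roughly like this (store name and key are illustrative):

```python
from crawlee.storages import Dataset, KeyValueStore


async def export_results() -> str:
    dataset = await Dataset.open(name='results')

    # Consolidate the whole dataset into one CSV record in a named store.
    await dataset.export_to(
        'results.csv',
        content_type='csv',
        to_key_value_store_name='exports',
    )

    kvs = await KeyValueStore.open(name='exports')
    return await kvs.get_value('results.csv')
```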
- - To open a key-value store, use the `open` class method, providing an `id`, `name`, or optional `configuration`. - If none are specified, the default store for the current crawler run is used. Attempting to open a store by `id` - that does not exist will raise an error; however, if accessed by `name`, the store will be created if it does not - already exist. + You can instantiate a key-value store using the `open` class method, which will create a store + with the specified name or id. The underlying storage implementation is determined by the configured + storage client. ### Usage ```python from crawlee.storages import KeyValueStore - kvs = await KeyValueStore.open(name='my_kvs') + # Open a named key-value store + kvs = await KeyValueStore.open(name='my-store') + + # Store and retrieve data + await kvs.set_value('product-1234.json', [{'name': 'Smartphone', 'price': 799.99}]) + product = await kvs.get_value('product-1234') ``` """ - # Cache for persistent (auto-saved) values - _general_cache: ClassVar[dict[str, dict[str, dict[str, JsonSerializable]]]] = {} - _persist_state_event_started = False + _cache_by_id: ClassVar[dict[str, KeyValueStore]] = {} + """A dictionary to cache key-value stores by their IDs.""" - def __init__(self, id: str, name: str | None, storage_client: StorageClient) -> None: - self._id = id - self._name = name - datetime_now = datetime.now(timezone.utc) - self._storage_object = StorageMetadata( - id=id, name=name, accessed_at=datetime_now, created_at=datetime_now, modified_at=datetime_now - ) + _cache_by_name: ClassVar[dict[str, KeyValueStore]] = {} + """A dictionary to cache key-value stores by their names.""" - # Get resource clients from storage client - self._resource_client = storage_client.key_value_store(self._id) - self._autosave_lock = asyncio.Lock() + def __init__(self, client: KeyValueStoreClient) -> None: + """Initialize a new instance. - @classmethod - def from_storage_object(cls, storage_client: StorageClient, storage_object: StorageMetadata) -> KeyValueStore: - """Initialize a new instance of KeyValueStore from a storage metadata object.""" - key_value_store = KeyValueStore( - id=storage_object.id, - name=storage_object.name, - storage_client=storage_client, - ) + Preferably use the `KeyValueStore.open` constructor to create a new instance. - key_value_store.storage_object = storage_object - return key_value_store + Args: + client: An instance of a key-value store client. 
+ """ + self._client = client - @property @override + @property def id(self) -> str: - return self._id + return self._client.metadata.id - @property @override - def name(self) -> str | None: - return self._name - @property - @override - def storage_object(self) -> StorageMetadata: - return self._storage_object + def name(self) -> str | None: + return self._client.metadata.name - @storage_object.setter @override - def storage_object(self, storage_object: StorageMetadata) -> None: - self._storage_object = storage_object - - async def get_info(self) -> KeyValueStoreMetadata | None: - """Get an object containing general information about the key value store.""" - return await self._resource_client.get() + @property + def metadata(self) -> KeyValueStoreMetadata: + return self._client.metadata @override @classmethod @@ -121,26 +92,43 @@ async def open( configuration: Configuration | None = None, storage_client: StorageClient | None = None, ) -> KeyValueStore: - from crawlee.storages._creation_management import open_storage + if id and name: + raise ValueError('Only one of "id" or "name" can be specified, not both.') + + # Check if key-value store is already cached by id or name + if id and id in cls._cache_by_id: + return cls._cache_by_id[id] + if name and name in cls._cache_by_name: + return cls._cache_by_name[name] - configuration = configuration or service_locator.get_configuration() - storage_client = storage_client or service_locator.get_storage_client() + configuration = service_locator.get_configuration() if configuration is None else configuration + storage_client = service_locator.get_storage_client() if storage_client is None else storage_client - return await open_storage( - storage_class=cls, + client = await storage_client.open_key_value_store_client( id=id, name=name, configuration=configuration, - storage_client=storage_client, ) + kvs = cls(client) + + # Cache the key-value store by id and name if available + if kvs.id: + cls._cache_by_id[kvs.id] = kvs + if kvs.name: + cls._cache_by_name[kvs.name] = kvs + + return kvs + @override async def drop(self) -> None: - from crawlee.storages._creation_management import remove_storage_from_cache + # Remove from cache before dropping + if self.id in self._cache_by_id: + del self._cache_by_id[self.id] + if self.name and self.name in self._cache_by_name: + del self._cache_by_name[self.name] - await self._resource_client.delete() - self._clear_cache() - remove_storage_from_cache(storage_class=self.__class__, id=self._id, name=self._name) + await self._client.drop() @overload async def get_value(self, key: str) -> Any: ... @@ -161,27 +149,9 @@ async def get_value(self, key: str, default_value: T | None = None) -> T | None: Returns: The value associated with the given key. `default_value` is used in case the record does not exist. """ - record = await self._resource_client.get_record(key) + record = await self._client.get_value(key=key) return record.value if record else default_value - async def iterate_keys(self, exclusive_start_key: str | None = None) -> AsyncIterator[KeyValueStoreKeyInfo]: - """Iterate over the existing keys in the KVS. - - Args: - exclusive_start_key: Key to start the iteration from. - - Yields: - Information about the key. 
- """ - while True: - list_keys = await self._resource_client.list_keys(exclusive_start_key=exclusive_start_key) - for item in list_keys.items: - yield KeyValueStoreKeyInfo(key=item.key, size=item.size) - - if not list_keys.is_truncated: - break - exclusive_start_key = list_keys.next_exclusive_start_key - async def set_value( self, key: str, @@ -192,91 +162,70 @@ async def set_value( Args: key: Key of the record to set. - value: Value to set. If `None`, the record is deleted. - content_type: Content type of the record. + value: Value to set. + content_type: The MIME content type string. """ - if value is None: - return await self._resource_client.delete_record(key) - - return await self._resource_client.set_record(key, value, content_type) + await self._client.set_value(key=key, value=value, content_type=content_type) - async def get_public_url(self, key: str) -> str: - """Get the public URL for the given key. + async def delete_value(self, key: str) -> None: + """Delete a value from the KVS. Args: - key: Key of the record for which URL is required. - - Returns: - The public URL for the given key. + key: Key of the record to delete. """ - return await self._resource_client.get_public_url(key) + await self._client.delete_value(key=key) - async def get_auto_saved_value( + async def iterate_keys( self, - key: str, - default_value: dict[str, JsonSerializable] | None = None, - ) -> dict[str, JsonSerializable]: - """Get a value from KVS that will be automatically saved on changes. + exclusive_start_key: str | None = None, + limit: int | None = None, + ) -> AsyncIterator[KeyValueStoreRecordMetadata]: + """Iterate over the existing keys in the KVS. Args: - key: Key of the record, to store the value. - default_value: Value to be used if the record does not exist yet. Should be a dictionary. + exclusive_start_key: Key to start the iteration from. + limit: Maximum number of keys to return. None means no limit. - Returns: - Return the value of the key. + Yields: + Information about the key. """ - default_value = {} if default_value is None else default_value + async for item in self._client.iterate_keys( + exclusive_start_key=exclusive_start_key, + limit=limit, + ): + yield item - async with self._autosave_lock: - if key in self._cache: - return self._cache[key] + async def list_keys( + self, + exclusive_start_key: str | None = None, + limit: int = 1000, + ) -> list[KeyValueStoreRecordMetadata]: + """List all the existing keys in the KVS. - value = await self.get_value(key, default_value) + It uses client's `iterate_keys` method to get the keys. - if not isinstance(value, dict): - raise TypeError( - f'Expected dictionary for persist state value at key "{key}, but got {type(value).__name__}' - ) + Args: + exclusive_start_key: Key to start the iteration from. + limit: Maximum number of keys to return. - self._cache[key] = value + Returns: + A list of keys in the KVS. + """ + return [ + key + async for key in self._client.iterate_keys( + exclusive_start_key=exclusive_start_key, + limit=limit, + ) + ] - self._ensure_persist_event() + async def get_public_url(self, key: str) -> str: + """Get the public URL for the given key. - return value + Args: + key: Key of the record for which URL is required. 
- @property - def _cache(self) -> dict[str, dict[str, JsonSerializable]]: - """Cache dictionary for storing auto-saved values indexed by store ID.""" - if self._id not in self._general_cache: - self._general_cache[self._id] = {} - return self._general_cache[self._id] - - async def _persist_save(self, _event_data: EventPersistStateData | None = None) -> None: - """Save cache with persistent values. Can be used in Event Manager.""" - for key, value in self._cache.items(): - await self.set_value(key, value) - - def _ensure_persist_event(self) -> None: - """Ensure persist state event handling if not already done.""" - if self._persist_state_event_started: - return - - event_manager = service_locator.get_event_manager() - event_manager.on(event=Event.PERSIST_STATE, listener=self._persist_save) - self._persist_state_event_started = True - - def _clear_cache(self) -> None: - """Clear cache with persistent values.""" - self._cache.clear() - - def _drop_persist_state_event(self) -> None: - """Off event_manager listener and drop event status.""" - if self._persist_state_event_started: - event_manager = service_locator.get_event_manager() - event_manager.off(event=Event.PERSIST_STATE, listener=self._persist_save) - self._persist_state_event_started = False - - async def persist_autosaved_values(self) -> None: - """Force persistent values to be saved without waiting for an event in Event Manager.""" - if self._persist_state_event_started: - await self._persist_save() + Returns: + The public URL for the given key. + """ + return await self._client.get_public_url(key=key) diff --git a/src/crawlee/storages/_request_queue.py b/src/crawlee/storages/_request_queue.py index b3274ccc81..d998cafe46 100644 --- a/src/crawlee/storages/_request_queue.py +++ b/src/crawlee/storages/_request_queue.py @@ -1,23 +1,16 @@ from __future__ import annotations import asyncio -from collections import deque -from contextlib import suppress -from datetime import datetime, timedelta, timezone +from datetime import timedelta from logging import getLogger -from typing import TYPE_CHECKING, Any, TypedDict, TypeVar +from typing import TYPE_CHECKING, ClassVar, TypeVar -from cachetools import LRUCache from typing_extensions import override -from crawlee import service_locator -from crawlee._utils.crypto import crypto_random_object_id +from crawlee import Request, service_locator from crawlee._utils.docs import docs_group -from crawlee._utils.requests import unique_key_to_request_id from crawlee._utils.wait import wait_for_all_tasks_for_finish -from crawlee.events import Event from crawlee.request_loaders import RequestManager -from crawlee.storage_clients.models import ProcessedRequest, RequestQueueMetadata, StorageMetadata from ._base import Storage @@ -27,131 +20,92 @@ from crawlee import Request from crawlee.configuration import Configuration from crawlee.storage_clients import StorageClient + from crawlee.storage_clients._base import RequestQueueClient + from crawlee.storage_clients.models import ProcessedRequest, RequestQueueMetadata logger = getLogger(__name__) T = TypeVar('T') -class CachedRequest(TypedDict): - id: str - was_already_handled: bool - hydrated: Request | None - lock_expires_at: datetime | None - forefront: bool - - @docs_group('Classes') class RequestQueue(Storage, RequestManager): - """Represents a queue storage for managing HTTP requests in web crawling operations. + """Request queue is a storage for managing HTTP requests. 
- The `RequestQueue` class handles a queue of HTTP requests, each identified by a unique URL, to facilitate structured - web crawling. It supports both breadth-first and depth-first crawling strategies, allowing for recursive crawling - starting from an initial set of URLs. Each URL in the queue is uniquely identified by a `unique_key`, which can be - customized to allow the same URL to be added multiple times under different keys. + The request queue class serves as a high-level interface for organizing and managing HTTP requests + during web crawling. It provides methods for adding, retrieving, and manipulating requests throughout + the crawling lifecycle, abstracting away the underlying storage implementation details. - Data can be stored either locally or in the cloud. It depends on the setup of underlying storage client. - By default a `MemoryStorageClient` is used, but it can be changed to a different one. + Request queue maintains the state of each URL to be crawled, tracking whether it has been processed, + is currently being handled, or is waiting in the queue. Each URL in the queue is uniquely identified + by a `unique_key` property, which prevents duplicate processing unless explicitly configured otherwise. - By default, data is stored using the following path structure: - ``` - {CRAWLEE_STORAGE_DIR}/request_queues/{QUEUE_ID}/{REQUEST_ID}.json - ``` - - `{CRAWLEE_STORAGE_DIR}`: The root directory for all storage data specified by the environment variable. - - `{QUEUE_ID}`: The identifier for the request queue, either "default" or as specified. - - `{REQUEST_ID}`: The unique identifier for each request in the queue. + The class supports both breadth-first and depth-first crawling strategies through its `forefront` parameter + when adding requests. It also provides mechanisms for error handling and request reclamation when + processing fails. - The `RequestQueue` supports both creating new queues and opening existing ones by `id` or `name`. Named queues - persist indefinitely, while unnamed queues expire after 7 days unless specified otherwise. The queue supports - mutable operations, allowing URLs to be added and removed as needed. + You can open a request queue using the `open` class method, specifying either a name or ID to identify + the queue. The underlying storage implementation is determined by the configured storage client. ### Usage ```python from crawlee.storages import RequestQueue - rq = await RequestQueue.open(name='my_rq') + # Open a request queue + rq = await RequestQueue.open(name='my_queue') + + # Add a request + await rq.add_request('https://example.com') + + # Process requests + request = await rq.fetch_next_request() + if request: + try: + # Process the request + # ... 
+ await rq.mark_request_as_handled(request) + except Exception: + await rq.reclaim_request(request) ``` """ - _MAX_CACHED_REQUESTS = 1_000_000 - """Maximum number of requests that can be cached.""" + _cache_by_id: ClassVar[dict[str, RequestQueue]] = {} + """A dictionary to cache request queues by their IDs.""" - def __init__( - self, - id: str, - name: str | None, - storage_client: StorageClient, - ) -> None: - config = service_locator.get_configuration() - event_manager = service_locator.get_event_manager() + _cache_by_name: ClassVar[dict[str, RequestQueue]] = {} + """A dictionary to cache request queues by their names.""" - self._id = id - self._name = name + _MAX_CACHED_REQUESTS = 1_000_000 + """Maximum number of requests that can be cached.""" - datetime_now = datetime.now(timezone.utc) - self._storage_object = StorageMetadata( - id=id, name=name, accessed_at=datetime_now, created_at=datetime_now, modified_at=datetime_now - ) + def __init__(self, client: RequestQueueClient) -> None: + """Initialize a new instance. - # Get resource clients from storage client - self._resource_client = storage_client.request_queue(self._id) - self._resource_collection_client = storage_client.request_queues() - - self._request_lock_time = timedelta(minutes=3) - self._queue_paused_for_migration = False - self._queue_has_locked_requests: bool | None = None - self._should_check_for_forefront_requests = False - - self._is_finished_log_throttle_counter = 0 - self._dequeued_request_count = 0 - - event_manager.on(event=Event.MIGRATING, listener=lambda _: setattr(self, '_queue_paused_for_migration', True)) - event_manager.on(event=Event.MIGRATING, listener=self._clear_possible_locks) - event_manager.on(event=Event.ABORTING, listener=self._clear_possible_locks) - - # Other internal attributes - self._tasks = list[asyncio.Task]() - self._client_key = crypto_random_object_id() - self._internal_timeout = config.internal_timeout or timedelta(minutes=5) - self._assumed_total_count = 0 - self._assumed_handled_count = 0 - self._queue_head = deque[str]() - self._list_head_and_lock_task: asyncio.Task | None = None - self._last_activity = datetime.now(timezone.utc) - self._requests_cache: LRUCache[str, CachedRequest] = LRUCache(maxsize=self._MAX_CACHED_REQUESTS) + Preferably use the `RequestQueue.open` constructor to create a new instance. - @classmethod - def from_storage_object(cls, storage_client: StorageClient, storage_object: StorageMetadata) -> RequestQueue: - """Initialize a new instance of RequestQueue from a storage metadata object.""" - request_queue = RequestQueue( - id=storage_object.id, - name=storage_object.name, - storage_client=storage_client, - ) + Args: + client: An instance of a request queue client. 
+ """ + self._client = client - request_queue.storage_object = storage_object - return request_queue + self._add_requests_tasks = list[asyncio.Task]() + """A list of tasks for adding requests to the queue.""" - @property @override + @property def id(self) -> str: - return self._id + return self._client.metadata.id - @property @override - def name(self) -> str | None: - return self._name - @property - @override - def storage_object(self) -> StorageMetadata: - return self._storage_object + def name(self) -> str | None: + return self._client.metadata.name - @storage_object.setter @override - def storage_object(self, storage_object: StorageMetadata) -> None: - self._storage_object = storage_object + @property + def metadata(self) -> RequestQueueMetadata: + return self._client.metadata @override @classmethod @@ -163,29 +117,43 @@ async def open( configuration: Configuration | None = None, storage_client: StorageClient | None = None, ) -> RequestQueue: - from crawlee.storages._creation_management import open_storage + if id and name: + raise ValueError('Only one of "id" or "name" can be specified, not both.') - configuration = configuration or service_locator.get_configuration() - storage_client = storage_client or service_locator.get_storage_client() + # Check if request queue is already cached by id or name + if id and id in cls._cache_by_id: + return cls._cache_by_id[id] + if name and name in cls._cache_by_name: + return cls._cache_by_name[name] - return await open_storage( - storage_class=cls, + configuration = service_locator.get_configuration() if configuration is None else configuration + storage_client = service_locator.get_storage_client() if storage_client is None else storage_client + + client = await storage_client.open_request_queue_client( id=id, name=name, configuration=configuration, - storage_client=storage_client, ) - @override - async def drop(self, *, timeout: timedelta | None = None) -> None: - from crawlee.storages._creation_management import remove_storage_from_cache + rq = cls(client) - # Wait for all tasks to finish - await wait_for_all_tasks_for_finish(self._tasks, logger=logger, timeout=timeout) + # Cache the request queue by id and name if available + if rq.id: + cls._cache_by_id[rq.id] = rq + if rq.name: + cls._cache_by_name[rq.name] = rq - # Delete the storage from the underlying client and remove it from the cache - await self._resource_client.delete() - remove_storage_from_cache(storage_class=self.__class__, id=self._id, name=self._name) + return rq + + @override + async def drop(self) -> None: + # Remove from cache before dropping + if self.id in self._cache_by_id: + del self._cache_by_id[self.id] + if self.name and self.name in self._cache_by_name: + del self._cache_by_name[self.name] + + await self._client.drop() @override async def add_request( @@ -195,40 +163,15 @@ async def add_request( forefront: bool = False, ) -> ProcessedRequest: request = self._transform_request(request) - self._last_activity = datetime.now(timezone.utc) - - cache_key = unique_key_to_request_id(request.unique_key) - cached_info = self._requests_cache.get(cache_key) - - if cached_info: - request.id = cached_info['id'] - # We may assume that if request is in local cache then also the information if the request was already - # handled is there because just one client should be using one queue. 
- return ProcessedRequest( - id=request.id, - unique_key=request.unique_key, - was_already_present=True, - was_already_handled=cached_info['was_already_handled'], - ) - - processed_request = await self._resource_client.add_request(request, forefront=forefront) - processed_request.unique_key = request.unique_key - - self._cache_request(cache_key, processed_request, forefront=forefront) - - if not processed_request.was_already_present and forefront: - self._should_check_for_forefront_requests = True - - if request.handled_at is None and not processed_request.was_already_present: - self._assumed_total_count += 1 - - return processed_request + response = await self._client.add_batch_of_requests([request], forefront=forefront) + return response.processed_requests[0] @override - async def add_requests_batched( + async def add_requests( self, requests: Sequence[str | Request], *, + forefront: bool = False, batch_size: int = 1000, wait_time_between_batches: timedelta = timedelta(seconds=1), wait_for_all_requests_to_be_added: bool = False, @@ -240,21 +183,31 @@ async def add_requests_batched( # Wait for the first batch to be added first_batch = transformed_requests[:batch_size] if first_batch: - await self._process_batch(first_batch, base_retry_wait=wait_time_between_batches) + await self._process_batch( + first_batch, + base_retry_wait=wait_time_between_batches, + forefront=forefront, + ) async def _process_remaining_batches() -> None: for i in range(batch_size, len(transformed_requests), batch_size): batch = transformed_requests[i : i + batch_size] - await self._process_batch(batch, base_retry_wait=wait_time_between_batches) + await self._process_batch( + batch, + base_retry_wait=wait_time_between_batches, + forefront=forefront, + ) if i + batch_size < len(transformed_requests): await asyncio.sleep(wait_time_secs) # Create and start the task to process remaining batches in the background remaining_batches_task = asyncio.create_task( - _process_remaining_batches(), name='request_queue_process_remaining_batches_task' + _process_remaining_batches(), + name='request_queue_process_remaining_batches_task', ) - self._tasks.append(remaining_batches_task) - remaining_batches_task.add_done_callback(lambda _: self._tasks.remove(remaining_batches_task)) + + self._add_requests_tasks.append(remaining_batches_task) + remaining_batches_task.add_done_callback(lambda _: self._add_requests_tasks.remove(remaining_batches_task)) # Wait for all tasks to finish if requested if wait_for_all_requests_to_be_added: @@ -264,42 +217,6 @@ async def _process_remaining_batches() -> None: timeout=wait_for_all_requests_to_be_added_timeout, ) - async def _process_batch(self, batch: Sequence[Request], base_retry_wait: timedelta, attempt: int = 1) -> None: - max_attempts = 5 - response = await self._resource_client.batch_add_requests(batch) - - if response.unprocessed_requests: - logger.debug(f'Following requests were not processed: {response.unprocessed_requests}.') - if attempt > max_attempts: - logger.warning( - f'Following requests were not processed even after {max_attempts} attempts:\n' - f'{response.unprocessed_requests}' - ) - else: - logger.debug('Retry to add requests.') - unprocessed_requests_unique_keys = {request.unique_key for request in response.unprocessed_requests} - retry_batch = [request for request in batch if request.unique_key in unprocessed_requests_unique_keys] - await asyncio.sleep((base_retry_wait * attempt).total_seconds()) - await self._process_batch(retry_batch, base_retry_wait=base_retry_wait, 
attempt=attempt + 1) - - request_count = len(batch) - len(response.unprocessed_requests) - self._assumed_total_count += request_count - if request_count: - logger.debug( - f'Added {request_count} requests to the queue. Processed requests: {response.processed_requests}' - ) - - async def get_request(self, request_id: str) -> Request | None: - """Retrieve a request from the queue. - - Args: - request_id: ID of the request to retrieve. - - Returns: - The retrieved request, or `None`, if it does not exist. - """ - return await self._resource_client.get_request(request_id) - async def fetch_next_request(self) -> Request | None: """Return the next request in the queue to be processed. @@ -313,75 +230,35 @@ async def fetch_next_request(self) -> Request | None: instead. Returns: - The request or `None` if there are no more pending requests. + The next request to process, or `None` if there are no more pending requests. """ - self._last_activity = datetime.now(timezone.utc) - - await self._ensure_head_is_non_empty() - - # We are likely done at this point. - if len(self._queue_head) == 0: - return None + return await self._client.fetch_next_request() - next_request_id = self._queue_head.popleft() - request = await self._get_or_hydrate_request(next_request_id) - - # NOTE: It can happen that the queue head index is inconsistent with the main queue table. - # This can occur in two situations: + async def get_request(self, request_id: str) -> Request | None: + """Retrieve a specific request from the queue by its ID. - # 1) - # Queue head index is ahead of the main table and the request is not present in the main table yet - # (i.e. get_request() returned null). In this case, keep the request marked as in progress for a short while, - # so that is_finished() doesn't return true and _ensure_head_is_non_empty() doesn't not load the request into - # the queueHeadDict straight again. After the interval expires, fetch_next_request() will try to fetch this - # request again, until it eventually appears in the main table. - if request is None: - logger.debug( - 'Cannot find a request from the beginning of queue, will be retried later', - extra={'nextRequestId': next_request_id}, - ) - return None - - # 2) - # Queue head index is behind the main table and the underlying request was already handled (by some other - # client, since we keep the track of handled requests in recently_handled dictionary). We just add the request - # to the recently_handled dictionary so that next call to _ensure_head_is_non_empty() will not put the request - # again to queue_head_dict. - if request.handled_at is not None: - logger.debug( - 'Request fetched from the beginning of queue was already handled', - extra={'nextRequestId': next_request_id}, - ) - return None + Args: + request_id: The ID of the request to retrieve. - self._dequeued_request_count += 1 - return request + Returns: + The request with the specified ID, or `None` if no such request exists. + """ + return await self._client.get_request(request_id) async def mark_request_as_handled(self, request: Request) -> ProcessedRequest | None: """Mark a request as handled after successful processing. - Handled requests will never again be returned by the `RequestQueue.fetch_next_request` method. + This method should be called after a request has been successfully processed. + Once marked as handled, the request will be removed from the queue and will + not be returned in subsequent calls to `fetch_next_request` method. Args: request: The request to mark as handled. 
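A usage sketch of the renamed `add_requests` helper above (formerly `add_requests_batched`), showing the new `forefront` flag alongside the existing batching knobs; the URLs, batch size, and timings are illustrative.

import asyncio
from datetime import timedelta

from crawlee.storages import RequestQueue


async def main() -> None:
    rq = await RequestQueue.open(name='my_queue')

    await rq.add_requests(
        ['https://crawlee.dev/', 'https://apify.com/'],
        forefront=True,  # new: place the whole batch at the head of the queue
        batch_size=500,
        wait_time_between_batches=timedelta(seconds=2),
        wait_for_all_requests_to_be_added=True,
    )


asyncio.run(main())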
Returns: - Information about the queue operation. `None` if the given request was not in progress. + Information about the queue operation. """ - self._last_activity = datetime.now(timezone.utc) - - if request.handled_at is None: - request.handled_at = datetime.now(timezone.utc) - - processed_request = await self._resource_client.update_request(request) - processed_request.unique_key = request.unique_key - self._dequeued_request_count -= 1 - - if not processed_request.was_already_handled: - self._assumed_handled_count += 1 - - self._cache_request(unique_key_to_request_id(request.unique_key), processed_request, forefront=False) - return processed_request + return await self._client.mark_request_as_handled(request) async def reclaim_request( self, @@ -389,325 +266,83 @@ async def reclaim_request( *, forefront: bool = False, ) -> ProcessedRequest | None: - """Reclaim a failed request back to the queue. + """Reclaim a failed request back to the queue for later processing. - The request will be returned for processing later again by another call to `RequestQueue.fetch_next_request`. + If a request fails during processing, this method can be used to return it to the queue. + The request will be returned for processing again in a subsequent call + to `RequestQueue.fetch_next_request`. Args: request: The request to return to the queue. - forefront: Whether to add the request to the head or the end of the queue. + forefront: If true, the request will be added to the beginning of the queue. + Otherwise, it will be added to the end. Returns: - Information about the queue operation. `None` if the given request was not in progress. + Information about the queue operation. """ - self._last_activity = datetime.now(timezone.utc) - - processed_request = await self._resource_client.update_request(request, forefront=forefront) - processed_request.unique_key = request.unique_key - self._cache_request(unique_key_to_request_id(request.unique_key), processed_request, forefront=forefront) - - if forefront: - self._should_check_for_forefront_requests = True - - if processed_request: - # Try to delete the request lock if possible - try: - await self._resource_client.delete_request_lock(request.id, forefront=forefront) - except Exception as err: - logger.debug(f'Failed to delete request lock for request {request.id}', exc_info=err) - - return processed_request + return await self._client.reclaim_request(request, forefront=forefront) async def is_empty(self) -> bool: - """Check whether the queue is empty. + """Check if the request queue is empty. + + An empty queue means that there are no requests currently in the queue, either pending or being processed. + However, this does not necessarily mean that the crawling operation is finished, as there still might be + tasks that could add additional requests to the queue. Returns: - bool: `True` if the next call to `RequestQueue.fetch_next_request` would return `None`, otherwise `False`. + True if the request queue is empty, False otherwise. """ - await self._ensure_head_is_non_empty() - return len(self._queue_head) == 0 + return await self._client.is_empty() async def is_finished(self) -> bool: - """Check whether the queue is finished. + """Check if the request queue is finished. - Due to the nature of distributed storage used by the queue, the function might occasionally return a false - negative, but it will never return a false positive. 
+ A finished queue means that all requests in the queue have been processed (the queue is empty) and there + are no more tasks that could add additional requests to the queue. This is the definitive way to check + if a crawling operation is complete. Returns: - bool: `True` if all requests were already handled and there are no more left. `False` otherwise. + True if the request queue is finished (empty and no pending add operations), False otherwise. """ - if self._tasks: - logger.debug('Background tasks are still in progress') - return False - - if self._queue_head: - logger.debug( - 'There are still ids in the queue head that are pending processing', - extra={ - 'queue_head_ids_pending': len(self._queue_head), - }, - ) - - return False - - await self._ensure_head_is_non_empty() - - if self._queue_head: - logger.debug('Queue head still returned requests that need to be processed') - + if self._add_requests_tasks: + logger.debug('Background add requests tasks are still in progress.') return False - # Could not lock any new requests - decide based on whether the queue contains requests locked by another client - if self._queue_has_locked_requests is not None: - if self._queue_has_locked_requests and self._dequeued_request_count == 0: - # The `% 25` was absolutely arbitrarily picked. It's just to not spam the logs too much. - if self._is_finished_log_throttle_counter % 25 == 0: - logger.info('The queue still contains requests locked by another client') - - self._is_finished_log_throttle_counter += 1 - - logger.debug( - f'Deciding if we are finished based on `queue_has_locked_requests` = {self._queue_has_locked_requests}' - ) - return not self._queue_has_locked_requests - - metadata = await self._resource_client.get() - if metadata is not None and not metadata.had_multiple_clients and not self._queue_head: - logger.debug('Queue head is empty and there are no other clients - we are finished') - + if await self.is_empty(): + logger.debug('The request queue is empty.') return True - # The following is a legacy algorithm for checking if the queue is finished. - # It is used only for request queue clients that do not provide the `queue_has_locked_requests` flag. 
- current_head = await self._resource_client.list_head(limit=2) - - if current_head.items: - logger.debug('The queue still contains unfinished requests or requests locked by another client') + return False - return len(current_head.items) == 0 - - async def get_info(self) -> RequestQueueMetadata | None: - """Get an object containing general information about the request queue.""" - return await self._resource_client.get() - - @override - async def get_handled_count(self) -> int: - return self._assumed_handled_count - - @override - async def get_total_count(self) -> int: - return self._assumed_total_count - - async def _ensure_head_is_non_empty(self) -> None: - # Stop fetching if we are paused for migration - if self._queue_paused_for_migration: - return - - # We want to fetch ahead of time to minimize dead time - if len(self._queue_head) > 1 and not self._should_check_for_forefront_requests: - return - - if self._list_head_and_lock_task is None: - task = asyncio.create_task(self._list_head_and_lock(), name='request_queue_list_head_and_lock_task') - - def callback(_: Any) -> None: - self._list_head_and_lock_task = None - - task.add_done_callback(callback) - self._list_head_and_lock_task = task - - await self._list_head_and_lock_task - - async def _list_head_and_lock(self) -> None: - # Make a copy so that we can clear the flag only if the whole method executes after the flag was set - # (i.e, it was not set in the middle of the execution of the method) - should_check_for_forefront_requests = self._should_check_for_forefront_requests - - limit = 25 - - response = await self._resource_client.list_and_lock_head( - limit=limit, lock_secs=int(self._request_lock_time.total_seconds()) - ) - - self._queue_has_locked_requests = response.queue_has_locked_requests - - head_id_buffer = list[str]() - forefront_head_id_buffer = list[str]() + async def _process_batch( + self, + batch: Sequence[Request], + *, + base_retry_wait: timedelta, + attempt: int = 1, + forefront: bool = False, + ) -> None: + """Process a batch of requests with automatic retry mechanism.""" + max_attempts = 5 + response = await self._client.add_batch_of_requests(batch, forefront=forefront) - for request in response.items: - # Queue head index might be behind the main table, so ensure we don't recycle requests - if not request.id or not request.unique_key: - logger.debug( - 'Skipping request from queue head, already in progress or recently handled', - extra={ - 'id': request.id, - 'unique_key': request.unique_key, - }, + if response.unprocessed_requests: + logger.debug(f'Following requests were not processed: {response.unprocessed_requests}.') + if attempt > max_attempts: + logger.warning( + f'Following requests were not processed even after {max_attempts} attempts:\n' + f'{response.unprocessed_requests}' ) - - # Remove the lock from the request for now, so that it can be picked up later - # This may/may not succeed, but that's fine - with suppress(Exception): - await self._resource_client.delete_request_lock(request.id) - - continue - - # If we remember that we added the request ourselves and we added it to the forefront, - # we will put it to the beginning of the local queue head to preserve the expected order. - # If we do not remember that, we will enqueue it normally. 
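A short sketch contrasting `is_empty` and `is_finished` as reimplemented above: the queue counts as finished only when it is empty and no background `add_requests` tasks remain, so a consumer loop can key off `is_finished` while tolerating temporarily empty fetches. The processing body is a placeholder.

import asyncio

from crawlee.storages import RequestQueue


async def drain(rq: RequestQueue) -> None:
    while not await rq.is_finished():
        request = await rq.fetch_next_request()
        if request is None:
            # Empty for now, but background add_requests tasks may still enqueue more.
            await asyncio.sleep(1)
            continue
        try:
            ...  # process the request here
            await rq.mark_request_as_handled(request)
        except Exception:
            await rq.reclaim_request(request)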
- cached_request = self._requests_cache.get(unique_key_to_request_id(request.unique_key)) - forefront = cached_request['forefront'] if cached_request else False - - if forefront: - forefront_head_id_buffer.insert(0, request.id) else: - head_id_buffer.append(request.id) - - self._cache_request( - unique_key_to_request_id(request.unique_key), - ProcessedRequest( - id=request.id, - unique_key=request.unique_key, - was_already_present=True, - was_already_handled=False, - ), - forefront=forefront, - ) - - for request_id in head_id_buffer: - self._queue_head.append(request_id) - - for request_id in forefront_head_id_buffer: - self._queue_head.appendleft(request_id) - - # If the queue head became too big, unlock the excess requests - to_unlock = list[str]() - while len(self._queue_head) > limit: - to_unlock.append(self._queue_head.pop()) - - if to_unlock: - await asyncio.gather( - *[self._resource_client.delete_request_lock(request_id) for request_id in to_unlock], - return_exceptions=True, # Just ignore the exceptions - ) - - # Unset the should_check_for_forefront_requests flag - the check is finished - if should_check_for_forefront_requests: - self._should_check_for_forefront_requests = False - - def _reset(self) -> None: - self._queue_head.clear() - self._list_head_and_lock_task = None - self._assumed_total_count = 0 - self._assumed_handled_count = 0 - self._requests_cache.clear() - self._last_activity = datetime.now(timezone.utc) - - def _cache_request(self, cache_key: str, processed_request: ProcessedRequest, *, forefront: bool) -> None: - self._requests_cache[cache_key] = { - 'id': processed_request.id, - 'was_already_handled': processed_request.was_already_handled, - 'hydrated': None, - 'lock_expires_at': None, - 'forefront': forefront, - } - - async def _get_or_hydrate_request(self, request_id: str) -> Request | None: - cached_entry = self._requests_cache.get(request_id) - - if not cached_entry: - # 2.1. Attempt to prolong the request lock to see if we still own the request - prolong_result = await self._prolong_request_lock(request_id) - - if not prolong_result: - return None - - # 2.1.1. If successful, hydrate the request and return it - hydrated_request = await self.get_request(request_id) - - # Queue head index is ahead of the main table and the request is not present in the main table yet - # (i.e. get_request() returned null). - if not hydrated_request: - # Remove the lock from the request for now, so that it can be picked up later - # This may/may not succeed, but that's fine - with suppress(Exception): - await self._resource_client.delete_request_lock(request_id) - - return None - - self._requests_cache[request_id] = { - 'id': request_id, - 'hydrated': hydrated_request, - 'was_already_handled': hydrated_request.handled_at is not None, - 'lock_expires_at': prolong_result, - 'forefront': False, - } - - return hydrated_request - - # 1.1. If hydrated, prolong the lock more and return it - if cached_entry['hydrated']: - # 1.1.1. If the lock expired on the hydrated requests, try to prolong. If we fail, we lost the request - # (or it was handled already) - if cached_entry['lock_expires_at'] and cached_entry['lock_expires_at'] < datetime.now(timezone.utc): - prolonged = await self._prolong_request_lock(cached_entry['id']) - - if not prolonged: - return None - - cached_entry['lock_expires_at'] = prolonged - - return cached_entry['hydrated'] - - # 1.2. 
If not hydrated, try to prolong the lock first (to ensure we keep it in our queue), hydrate and return it - prolonged = await self._prolong_request_lock(cached_entry['id']) - - if not prolonged: - return None - - # This might still return null if the queue head is inconsistent with the main queue table. - hydrated_request = await self.get_request(cached_entry['id']) - - cached_entry['hydrated'] = hydrated_request - - # Queue head index is ahead of the main table and the request is not present in the main table yet - # (i.e. get_request() returned null). - if not hydrated_request: - # Remove the lock from the request for now, so that it can be picked up later - # This may/may not succeed, but that's fine - with suppress(Exception): - await self._resource_client.delete_request_lock(cached_entry['id']) - - return None + logger.debug('Retry to add requests.') + unprocessed_requests_unique_keys = {request.unique_key for request in response.unprocessed_requests} + retry_batch = [request for request in batch if request.unique_key in unprocessed_requests_unique_keys] + await asyncio.sleep((base_retry_wait * attempt).total_seconds()) + await self._process_batch(retry_batch, base_retry_wait=base_retry_wait, attempt=attempt + 1) - return hydrated_request + request_count = len(batch) - len(response.unprocessed_requests) - async def _prolong_request_lock(self, request_id: str) -> datetime | None: - try: - res = await self._resource_client.prolong_request_lock( - request_id, lock_secs=int(self._request_lock_time.total_seconds()) - ) - except Exception as err: - # Most likely we do not own the lock anymore - logger.warning( - f'Failed to prolong lock for cached request {request_id}, either lost the lock ' - 'or the request was already handled\n', - exc_info=err, + if request_count: + logger.debug( + f'Added {request_count} requests to the queue. Processed requests: {response.processed_requests}' ) - return None - else: - return res.lock_expires_at - - async def _clear_possible_locks(self) -> None: - self._queue_paused_for_migration = True - request_id: str | None = None - - while True: - try: - request_id = self._queue_head.pop() - except LookupError: - break - - with suppress(Exception): - await self._resource_client.delete_request_lock(request_id) - # If this fails, we don't have the lock, or the request was never locked. 
Either way it's fine diff --git a/src/crawlee/storages/_types.py b/src/crawlee/storages/_types.py new file mode 100644 index 0000000000..bc99ce72fd --- /dev/null +++ b/src/crawlee/storages/_types.py @@ -0,0 +1,159 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Literal, TypedDict + +if TYPE_CHECKING: + import json + from collections.abc import Callable + from datetime import datetime + + from typing_extensions import NotRequired, Required + + from crawlee import Request + + +class CachedRequest(TypedDict): + """Represent a cached request in the `RequestQueue`.""" + + id: str + """The ID of the request.""" + + was_already_handled: bool + """Indicates whether the request was already handled.""" + + hydrated: Request | None + """The hydrated request object.""" + + lock_expires_at: datetime | None + """The time at which the lock on the request expires.""" + + forefront: bool + """Indicates whether the request is at the forefront of the queue.""" + + +class IterateKwargs(TypedDict): + """Keyword arguments for dataset's `iterate` method.""" + + offset: NotRequired[int] + """Skips the specified number of items at the start.""" + + limit: NotRequired[int | None] + """The maximum number of items to retrieve. Unlimited if None.""" + + clean: NotRequired[bool] + """Return only non-empty items and excludes hidden fields. Shortcut for skip_hidden and skip_empty.""" + + desc: NotRequired[bool] + """Set to True to sort results in descending order.""" + + fields: NotRequired[list[str]] + """Fields to include in each item. Sorts fields as specified if provided.""" + + omit: NotRequired[list[str]] + """Fields to exclude from each item.""" + + unwind: NotRequired[str] + """Unwinds items by a specified array field, turning each element into a separate item.""" + + skip_empty: NotRequired[bool] + """Excludes empty items from the results if True.""" + + skip_hidden: NotRequired[bool] + """Excludes fields starting with '#' if True.""" + + +class GetDataKwargs(IterateKwargs): + """Keyword arguments for dataset's `get_data` method.""" + + flatten: NotRequired[list[str]] + """Fields to be flattened in returned items.""" + + view: NotRequired[str] + """Specifies the dataset view to be used.""" + + +class ExportToKwargs(TypedDict): + """Keyword arguments for dataset's `export_to` method.""" + + key: Required[str] + """The key under which to save the data.""" + + content_type: NotRequired[Literal['json', 'csv']] + """The format in which to export the data. Either 'json' or 'csv'.""" + + to_key_value_store_id: NotRequired[str] + """ID of the key-value store to save the exported file.""" + + to_key_value_store_name: NotRequired[str] + """Name of the key-value store to save the exported file.""" + + +class ExportDataJsonKwargs(TypedDict): + """Keyword arguments for dataset's `export_data_json` method.""" + + skipkeys: NotRequired[bool] + """If True (default: False), dict keys that are not of a basic type (str, int, float, bool, None) will be skipped + instead of raising a `TypeError`.""" + + ensure_ascii: NotRequired[bool] + """Determines if non-ASCII characters should be escaped in the output JSON string.""" + + check_circular: NotRequired[bool] + """If False (default: True), skips the circular reference check for container types. A circular reference will + result in a `RecursionError` or worse if unchecked.""" + + allow_nan: NotRequired[bool] + """If False (default: True), raises a ValueError for out-of-range float values (nan, inf, -inf) to strictly comply + with the JSON specification. 
If True, uses their JavaScript equivalents (NaN, Infinity, -Infinity).""" + + cls: NotRequired[type[json.JSONEncoder]] + """Allows specifying a custom JSON encoder.""" + + indent: NotRequired[int] + """Specifies the number of spaces for indentation in the pretty-printed JSON output.""" + + separators: NotRequired[tuple[str, str]] + """A tuple of (item_separator, key_separator). The default is (', ', ': ') if indent is None and (',', ': ') + otherwise.""" + + default: NotRequired[Callable] + """A function called for objects that can't be serialized otherwise. It should return a JSON-encodable version + of the object or raise a `TypeError`.""" + + sort_keys: NotRequired[bool] + """Specifies whether the output JSON object should have keys sorted alphabetically.""" + + +class ExportDataCsvKwargs(TypedDict): + """Keyword arguments for dataset's `export_data_csv` method.""" + + dialect: NotRequired[str] + """Specifies a dialect to be used in CSV parsing and writing.""" + + delimiter: NotRequired[str] + """A one-character string used to separate fields. Defaults to ','.""" + + doublequote: NotRequired[bool] + """Controls how instances of `quotechar` inside a field should be quoted. When True, the character is doubled; + when False, the `escapechar` is used as a prefix. Defaults to True.""" + + escapechar: NotRequired[str] + """A one-character string used to escape the delimiter if `quoting` is set to `QUOTE_NONE` and the `quotechar` + if `doublequote` is False. Defaults to None, disabling escaping.""" + + lineterminator: NotRequired[str] + """The string used to terminate lines produced by the writer. Defaults to '\\r\\n'.""" + + quotechar: NotRequired[str] + """A one-character string used to quote fields containing special characters, like the delimiter or quotechar, + or fields containing new-line characters. Defaults to '\"'.""" + + quoting: NotRequired[int] + """Controls when quotes should be generated by the writer and recognized by the reader. Can take any of + the `QUOTE_*` constants, with a default of `QUOTE_MINIMAL`.""" + + skipinitialspace: NotRequired[bool] + """When True, spaces immediately following the delimiter are ignored. Defaults to False.""" + + strict: NotRequired[bool] + """When True, raises an exception on bad CSV input. 
Defaults to False.""" diff --git a/tests/e2e/project_template/utils.py b/tests/e2e/project_template/utils.py index 3bc5be4ea6..685e8c45e8 100644 --- a/tests/e2e/project_template/utils.py +++ b/tests/e2e/project_template/utils.py @@ -20,23 +20,25 @@ def patch_crawlee_version_in_project( def _patch_crawlee_version_in_requirements_txt_based_project(project_path: Path, wheel_path: Path) -> None: # Get any extras - with open(project_path / 'requirements.txt') as f: + requirements_path = project_path / 'requirements.txt' + with requirements_path.open() as f: requirements = f.read() crawlee_extras = re.findall(r'crawlee(\[.*\])', requirements)[0] or '' # Modify requirements.txt to use crawlee from wheel file instead of from Pypi - with open(project_path / 'requirements.txt') as f: + with requirements_path.open() as f: modified_lines = [] for line in f: if 'crawlee' in line: modified_lines.append(f'./{wheel_path.name}{crawlee_extras}\n') else: modified_lines.append(line) - with open(project_path / 'requirements.txt', 'w') as f: + with requirements_path.open('w') as f: f.write(''.join(modified_lines)) # Patch the dockerfile to have wheel file available - with open(project_path / 'Dockerfile') as f: + dockerfile_path = project_path / 'Dockerfile' + with dockerfile_path.open() as f: modified_lines = [] for line in f: modified_lines.append(line) @@ -49,19 +51,21 @@ def _patch_crawlee_version_in_requirements_txt_based_project(project_path: Path, f'RUN pip install ./{wheel_path.name}{crawlee_extras} --force-reinstall\n', ] ) - with open(project_path / 'Dockerfile', 'w') as f: + with dockerfile_path.open('w') as f: f.write(''.join(modified_lines)) def _patch_crawlee_version_in_pyproject_toml_based_project(project_path: Path, wheel_path: Path) -> None: """Ensure that the test is using current version of the crawlee from the source and not from Pypi.""" # Get any extras - with open(project_path / 'pyproject.toml') as f: + pyproject_path = project_path / 'pyproject.toml' + with pyproject_path.open() as f: pyproject = f.read() crawlee_extras = re.findall(r'crawlee(\[.*\])', pyproject)[0] or '' # Inject crawlee wheel file to the docker image and update project to depend on it.""" - with open(project_path / 'Dockerfile') as f: + dockerfile_path = project_path / 'Dockerfile' + with dockerfile_path.open() as f: modified_lines = [] for line in f: modified_lines.append(line) @@ -94,5 +98,5 @@ def _patch_crawlee_version_in_pyproject_toml_based_project(project_path: Path, w f'RUN {package_manager} lock\n', ] ) - with open(project_path / 'Dockerfile', 'w') as f: + with dockerfile_path.open('w') as f: f.write(''.join(modified_lines)) diff --git a/tests/unit/_utils/test_file.py b/tests/unit/_utils/test_file.py index a86291b43f..0762e1d966 100644 --- a/tests/unit/_utils/test_file.py +++ b/tests/unit/_utils/test_file.py @@ -1,6 +1,5 @@ from __future__ import annotations -import io from datetime import datetime, timezone from pathlib import Path @@ -12,7 +11,6 @@ force_remove, force_rename, is_content_type, - is_file_or_bytes, json_dumps, ) @@ -25,15 +23,6 @@ async def test_json_dumps() -> None: assert await json_dumps(datetime(2022, 1, 1, tzinfo=timezone.utc)) == '"2022-01-01 00:00:00+00:00"' -def test_is_file_or_bytes() -> None: - assert is_file_or_bytes(b'bytes') is True - assert is_file_or_bytes(bytearray(b'bytearray')) is True - assert is_file_or_bytes(io.BytesIO(b'some bytes')) is True - assert is_file_or_bytes(io.StringIO('string')) is True - assert is_file_or_bytes('just a regular string') is False - assert 
is_file_or_bytes(12345) is False - - @pytest.mark.parametrize( ('content_type_enum', 'content_type', 'expected_result'), [ @@ -115,7 +104,7 @@ async def test_force_remove(tmp_path: Path) -> None: assert test_file_path.exists() is False # Remove the file if it exists - with open(test_file_path, 'a', encoding='utf-8'): # noqa: ASYNC230 + with test_file_path.open('a', encoding='utf-8'): pass assert test_file_path.exists() is True await force_remove(test_file_path) @@ -134,11 +123,11 @@ async def test_force_rename(tmp_path: Path) -> None: # Will remove dst_dir if it exists (also covers normal case) # Create the src_dir with a file in it src_dir.mkdir() - with open(src_file, 'a', encoding='utf-8'): # noqa: ASYNC230 + with src_file.open('a', encoding='utf-8'): pass # Create the dst_dir with a file in it dst_dir.mkdir() - with open(dst_file, 'a', encoding='utf-8'): # noqa: ASYNC230 + with dst_file.open('a', encoding='utf-8'): pass assert src_file.exists() is True assert dst_file.exists() is True diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index b7ac06d124..1b73df5743 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -13,12 +13,10 @@ from uvicorn.config import Config from crawlee import service_locator -from crawlee.configuration import Configuration from crawlee.fingerprint_suite._browserforge_adapter import get_available_header_network from crawlee.http_clients import CurlImpersonateHttpClient, HttpxHttpClient from crawlee.proxy_configuration import ProxyInfo -from crawlee.storage_clients import MemoryStorageClient -from crawlee.storages import KeyValueStore, _creation_management +from crawlee.storages import KeyValueStore from tests.unit.server import TestServer, app, serve_in_thread if TYPE_CHECKING: @@ -64,14 +62,6 @@ def _prepare_test_env() -> None: service_locator._event_manager = None service_locator._storage_client = None - # Clear creation-related caches to ensure no state is carried over between tests. - monkeypatch.setattr(_creation_management, '_cache_dataset_by_id', {}) - monkeypatch.setattr(_creation_management, '_cache_dataset_by_name', {}) - monkeypatch.setattr(_creation_management, '_cache_kvs_by_id', {}) - monkeypatch.setattr(_creation_management, '_cache_kvs_by_name', {}) - monkeypatch.setattr(_creation_management, '_cache_rq_by_id', {}) - monkeypatch.setattr(_creation_management, '_cache_rq_by_name', {}) - # Verify that the test environment was set up correctly. 
assert os.environ.get('CRAWLEE_STORAGE_DIR') == str(tmp_path) assert service_locator._configuration_was_retrieved is False @@ -149,18 +139,6 @@ async def disabled_proxy(proxy_info: ProxyInfo) -> AsyncGenerator[ProxyInfo, Non yield proxy_info -@pytest.fixture -def memory_storage_client(tmp_path: Path) -> MemoryStorageClient: - """A fixture for testing the memory storage client and its resource clients.""" - config = Configuration( - persist_storage=True, - write_metadata=True, - crawlee_storage_dir=str(tmp_path), # type: ignore[call-arg] - ) - - return MemoryStorageClient.from_config(config) - - @pytest.fixture(scope='session') def header_network() -> dict: return get_available_header_network() diff --git a/tests/unit/crawlers/_basic/test_basic_crawler.py b/tests/unit/crawlers/_basic/test_basic_crawler.py index 40f57de5ea..5cc572c16a 100644 --- a/tests/unit/crawlers/_basic/test_basic_crawler.py +++ b/tests/unit/crawlers/_basic/test_basic_crawler.py @@ -41,7 +41,7 @@ async def test_processes_requests_from_explicit_queue() -> None: queue = await RequestQueue.open() - await queue.add_requests_batched(['http://a.com/', 'http://b.com/', 'http://c.com/']) + await queue.add_requests(['http://a.com/', 'http://b.com/', 'http://c.com/']) crawler = BasicCrawler(request_manager=queue) calls = list[str]() @@ -57,7 +57,7 @@ async def handler(context: BasicCrawlingContext) -> None: async def test_processes_requests_from_request_source_tandem() -> None: request_queue = await RequestQueue.open() - await request_queue.add_requests_batched(['http://a.com/', 'http://b.com/', 'http://c.com/']) + await request_queue.add_requests(['http://a.com/', 'http://b.com/', 'http://c.com/']) request_list = RequestList(['http://a.com/', 'http://d.com', 'http://e.com']) diff --git a/tests/unit/storage_clients/_file_system/test_fs_dataset_client.py b/tests/unit/storage_clients/_file_system/test_fs_dataset_client.py new file mode 100644 index 0000000000..e832c1f4c1 --- /dev/null +++ b/tests/unit/storage_clients/_file_system/test_fs_dataset_client.py @@ -0,0 +1,351 @@ +from __future__ import annotations + +import asyncio +import json +from datetime import datetime +from pathlib import Path +from typing import TYPE_CHECKING + +import pytest + +from crawlee._consts import METADATA_FILENAME +from crawlee.configuration import Configuration +from crawlee.storage_clients import FileSystemStorageClient +from crawlee.storage_clients._file_system import FileSystemDatasetClient +from crawlee.storage_clients.models import DatasetItemsListPage + +if TYPE_CHECKING: + from collections.abc import AsyncGenerator + + + + +@pytest.fixture +def configuration(tmp_path: Path) -> Configuration: + return Configuration( + crawlee_storage_dir=str(tmp_path), # type: ignore[call-arg] + ) + + +@pytest.fixture +async def dataset_client(configuration: Configuration) -> AsyncGenerator[FileSystemDatasetClient, None]: + """A fixture for a file system dataset client.""" + client = await FileSystemStorageClient().open_dataset_client( + name='test_dataset', + configuration=configuration, + ) + yield client + await client.drop() + + +async def test_open_creates_new_dataset(configuration: Configuration) -> None: + """Test that open() creates a new dataset with proper metadata when it doesn't exist.""" + client = await FileSystemStorageClient().open_dataset_client( + name='new_dataset', + configuration=configuration, + ) + + # Verify correct client type and properties + assert isinstance(client, FileSystemDatasetClient) + assert client.metadata.id is not None + 
assert client.metadata.name == 'new_dataset' + assert client.metadata.item_count == 0 + assert isinstance(client.metadata.created_at, datetime) + assert isinstance(client.metadata.accessed_at, datetime) + assert isinstance(client.metadata.modified_at, datetime) + + # Verify files were created + assert client.path_to_dataset.exists() + assert client.path_to_metadata.exists() + + # Verify metadata content + with client.path_to_metadata.open() as f: + metadata = json.load(f) + assert metadata['id'] == client.metadata.id + assert metadata['name'] == 'new_dataset' + assert metadata['item_count'] == 0 + + +async def test_open_existing_dataset( + dataset_client: FileSystemDatasetClient, + configuration: Configuration, +) -> None: + """Test that open() loads an existing dataset correctly.""" + configuration.purge_on_start = False + + # Open the same dataset again + reopened_client = await FileSystemStorageClient().open_dataset_client( + name=dataset_client.metadata.name, + configuration=configuration, + ) + + # Verify client properties + assert dataset_client.metadata.id == reopened_client.metadata.id + assert dataset_client.metadata.name == reopened_client.metadata.name + assert dataset_client.metadata.item_count == reopened_client.metadata.item_count + + # Verify clients (python) ids + assert id(dataset_client) == id(reopened_client) + + +async def test_dataset_client_purge_on_start(configuration: Configuration) -> None: + """Test that purge_on_start=True clears existing data in the dataset.""" + configuration.purge_on_start = True + + # Create dataset and add data + dataset_client1 = await FileSystemStorageClient().open_dataset_client( + name='test-purge-dataset', + configuration=configuration, + ) + await dataset_client1.push_data({'item': 'initial data'}) + + # Verify data was added + items = await dataset_client1.get_data() + assert len(items.items) == 1 + + # Reopen + dataset_client2 = await FileSystemStorageClient().open_dataset_client( + name='test-purge-dataset', + configuration=configuration, + ) + + # Verify data was purged + items = await dataset_client2.get_data() + assert len(items.items) == 0 + + +async def test_dataset_client_no_purge_on_start(configuration: Configuration) -> None: + """Test that purge_on_start=False keeps existing data in the dataset.""" + configuration.purge_on_start = False + + # Create dataset and add data + dataset_client1 = await FileSystemStorageClient().open_dataset_client( + name='test-no-purge-dataset', + configuration=configuration, + ) + await dataset_client1.push_data({'item': 'preserved data'}) + + # Reopen + dataset_client2 = await FileSystemStorageClient().open_dataset_client( + name='test-no-purge-dataset', + configuration=configuration, + ) + + # Verify data was preserved + items = await dataset_client2.get_data() + assert len(items.items) == 1 + assert items.items[0]['item'] == 'preserved data' + + +async def test_open_with_id_raises_error(configuration: Configuration) -> None: + """Test that open() raises an error when an ID is provided.""" + with pytest.raises(ValueError, match='not supported for file system storage client'): + await FileSystemStorageClient().open_dataset_client(id='some-id', configuration=configuration) + + +async def test_push_data_single_item(dataset_client: FileSystemDatasetClient) -> None: + """Test pushing a single item to the dataset.""" + item = {'key': 'value', 'number': 42} + await dataset_client.push_data(item) + + # Verify item count was updated + assert dataset_client.metadata.item_count == 1 + + all_files = 
list(dataset_client.path_to_dataset.glob('*.json')) + assert len(all_files) == 2 # 1 data file + 1 metadata file + + # Verify item was persisted + data_files = [item for item in all_files if item.name != METADATA_FILENAME] + assert len(data_files) == 1 + + # Verify file content + with Path(data_files[0]).open() as f: + saved_item = json.load(f) + assert saved_item == item + + +async def test_push_data_multiple_items(dataset_client: FileSystemDatasetClient) -> None: + """Test pushing multiple items to the dataset.""" + items = [{'id': 1, 'name': 'Item 1'}, {'id': 2, 'name': 'Item 2'}, {'id': 3, 'name': 'Item 3'}] + await dataset_client.push_data(items) + + # Verify item count was updated + assert dataset_client.metadata.item_count == 3 + + all_files = list(dataset_client.path_to_dataset.glob('*.json')) + assert len(all_files) == 4 # 3 data files + 1 metadata file + + # Verify items were saved to files + data_files = [f for f in all_files if f.name != METADATA_FILENAME] + assert len(data_files) == 3 + + +async def test_get_data_empty_dataset(dataset_client: FileSystemDatasetClient) -> None: + """Test getting data from an empty dataset returns empty list.""" + result = await dataset_client.get_data() + + assert isinstance(result, DatasetItemsListPage) + assert result.count == 0 + assert result.total == 0 + assert result.items == [] + + +async def test_get_data_with_items(dataset_client: FileSystemDatasetClient) -> None: + """Test getting data from a dataset returns all items in order with correct properties.""" + # Add some items + items = [{'id': 1, 'name': 'Item 1'}, {'id': 2, 'name': 'Item 2'}, {'id': 3, 'name': 'Item 3'}] + await dataset_client.push_data(items) + + # Get all items + result = await dataset_client.get_data() + + assert result.count == 3 + assert result.total == 3 + assert len(result.items) == 3 + assert result.items[0]['id'] == 1 + assert result.items[1]['id'] == 2 + assert result.items[2]['id'] == 3 + + +async def test_get_data_with_pagination(dataset_client: FileSystemDatasetClient) -> None: + """Test getting data with offset and limit parameters for pagination implementation.""" + # Add some items + items = [{'id': i} for i in range(1, 11)] # 10 items + await dataset_client.push_data(items) + + # Test offset + result = await dataset_client.get_data(offset=3) + assert result.count == 7 + assert result.offset == 3 + assert result.items[0]['id'] == 4 + + # Test limit + result = await dataset_client.get_data(limit=5) + assert result.count == 5 + assert result.limit == 5 + assert result.items[-1]['id'] == 5 + + # Test both offset and limit + result = await dataset_client.get_data(offset=2, limit=3) + assert result.count == 3 + assert result.offset == 2 + assert result.limit == 3 + assert result.items[0]['id'] == 3 + assert result.items[-1]['id'] == 5 + + +async def test_get_data_descending_order(dataset_client: FileSystemDatasetClient) -> None: + """Test getting data in descending order reverses the item order.""" + # Add some items + items = [{'id': i} for i in range(1, 6)] # 5 items + await dataset_client.push_data(items) + + # Get items in descending order + result = await dataset_client.get_data(desc=True) + + assert result.desc is True + assert result.items[0]['id'] == 5 + assert result.items[-1]['id'] == 1 + + +async def test_get_data_skip_empty(dataset_client: FileSystemDatasetClient) -> None: + """Test getting data with skip_empty option filters out empty items when True.""" + # Add some items including an empty one + items = [ + {'id': 1, 'name': 'Item 1'}, + {}, # 
Empty item + {'id': 3, 'name': 'Item 3'}, + ] + await dataset_client.push_data(items) + + # Get all items + result = await dataset_client.get_data() + assert result.count == 3 + + # Get non-empty items + result = await dataset_client.get_data(skip_empty=True) + assert result.count == 2 + assert all(item != {} for item in result.items) + + +async def test_iterate(dataset_client: FileSystemDatasetClient) -> None: + """Test iterating over dataset items yields each item in the original order.""" + # Add some items + items = [{'id': i} for i in range(1, 6)] # 5 items + await dataset_client.push_data(items) + + # Iterate over all items + collected_items = [item async for item in dataset_client.iterate_items()] + + assert len(collected_items) == 5 + assert collected_items[0]['id'] == 1 + assert collected_items[-1]['id'] == 5 + + +async def test_iterate_with_options(dataset_client: FileSystemDatasetClient) -> None: + """Test iterating with offset, limit and desc parameters works the same as with get_data().""" + # Add some items + items = [{'id': i} for i in range(1, 11)] # 10 items + await dataset_client.push_data(items) + + # Test with offset and limit + collected_items = [item async for item in dataset_client.iterate_items(offset=3, limit=3)] + + assert len(collected_items) == 3 + assert collected_items[0]['id'] == 4 + assert collected_items[-1]['id'] == 6 + + # Test with descending order + collected_items = [] + async for item in dataset_client.iterate_items(desc=True, limit=3): + collected_items.append(item) + + assert len(collected_items) == 3 + assert collected_items[0]['id'] == 10 + assert collected_items[-1]['id'] == 8 + + +async def test_drop(dataset_client: FileSystemDatasetClient) -> None: + """Test dropping a dataset removes the entire dataset directory from disk.""" + await dataset_client.push_data({'test': 'data'}) + + assert dataset_client.metadata.name in FileSystemDatasetClient._cache_by_name + assert dataset_client.path_to_dataset.exists() + + # Drop the dataset + await dataset_client.drop() + + assert dataset_client.metadata.name not in FileSystemDatasetClient._cache_by_name + assert not dataset_client.path_to_dataset.exists() + + +async def test_metadata_updates(dataset_client: FileSystemDatasetClient) -> None: + """Test that metadata timestamps are updated correctly after read and write operations.""" + # Record initial timestamps + initial_created = dataset_client.metadata.created_at + initial_accessed = dataset_client.metadata.accessed_at + initial_modified = dataset_client.metadata.modified_at + + # Wait a moment to ensure timestamps can change + await asyncio.sleep(0.01) + + # Perform an operation that updates accessed_at + await dataset_client.get_data() + + # Verify timestamps + assert dataset_client.metadata.created_at == initial_created + assert dataset_client.metadata.accessed_at > initial_accessed + assert dataset_client.metadata.modified_at == initial_modified + + accessed_after_get = dataset_client.metadata.accessed_at + + # Wait a moment to ensure timestamps can change + await asyncio.sleep(0.01) + + # Perform an operation that updates modified_at + await dataset_client.push_data({'new': 'item'}) + + # Verify timestamps again + assert dataset_client.metadata.created_at == initial_created + assert dataset_client.metadata.modified_at > initial_modified + assert dataset_client.metadata.accessed_at > accessed_after_get diff --git a/tests/unit/storage_clients/_file_system/test_fs_kvs_client.py b/tests/unit/storage_clients/_file_system/test_fs_kvs_client.py new file 
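For readers following the new file-system dataset tests above, the API they exercise can be pictured as a small standalone script. This is only an illustrative sketch assembled from the calls used in those tests (FileSystemStorageClient.open_dataset_client, push_data, get_data, drop); the storage directory and dataset name are made-up example values, not something defined by this change.

import asyncio

from crawlee.configuration import Configuration
from crawlee.storage_clients import FileSystemStorageClient


async def main() -> None:
    # Illustrative storage location; the tests above use a pytest tmp_path instead.
    configuration = Configuration(
        crawlee_storage_dir='./example-storage',  # type: ignore[call-arg]
    )

    # Open (or create) a named dataset backed by the file system.
    dataset_client = await FileSystemStorageClient().open_dataset_client(
        name='example-dataset',
        configuration=configuration,
    )

    # Push a single item and a batch, as the push_data tests do.
    await dataset_client.push_data({'key': 'value'})
    await dataset_client.push_data([{'id': 1}, {'id': 2}, {'id': 3}])

    # Read everything back; get_data() returns a page with `count` and `items`.
    page = await dataset_client.get_data()
    print(page.count, page.items)

    # Remove the dataset directory from disk, as the drop() test does.
    await dataset_client.drop()


asyncio.run(main())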
mode 100644 index 0000000000..95ae2aa929 --- /dev/null +++ b/tests/unit/storage_clients/_file_system/test_fs_kvs_client.py @@ -0,0 +1,393 @@ +from __future__ import annotations + +import asyncio +import json +from datetime import datetime +from typing import TYPE_CHECKING + +import pytest + +from crawlee._consts import METADATA_FILENAME +from crawlee.configuration import Configuration +from crawlee.storage_clients import FileSystemStorageClient +from crawlee.storage_clients._file_system import FileSystemKeyValueStoreClient + +if TYPE_CHECKING: + from collections.abc import AsyncGenerator + from pathlib import Path + + + + +@pytest.fixture +def configuration(tmp_path: Path) -> Configuration: + return Configuration( + crawlee_storage_dir=str(tmp_path), # type: ignore[call-arg] + ) + + +@pytest.fixture +async def kvs_client(configuration: Configuration) -> AsyncGenerator[FileSystemKeyValueStoreClient, None]: + """A fixture for a file system key-value store client.""" + client = await FileSystemStorageClient().open_key_value_store_client( + name='test_kvs', + configuration=configuration, + ) + yield client + await client.drop() + + +async def test_open_creates_new_kvs(configuration: Configuration) -> None: + """Test that open() creates a new key-value store with proper metadata and files on disk.""" + client = await FileSystemStorageClient().open_key_value_store_client( + name='new_kvs', + configuration=configuration, + ) + + # Verify correct client type and properties + assert isinstance(client, FileSystemKeyValueStoreClient) + assert client.metadata.id is not None + assert client.metadata.name == 'new_kvs' + assert isinstance(client.metadata.created_at, datetime) + assert isinstance(client.metadata.accessed_at, datetime) + assert isinstance(client.metadata.modified_at, datetime) + + # Verify files were created + assert client.path_to_kvs.exists() + assert client.path_to_metadata.exists() + + # Verify metadata content + with client.path_to_metadata.open() as f: + metadata = json.load(f) + assert metadata['id'] == client.metadata.id + assert metadata['name'] == 'new_kvs' + + +async def test_open_existing_kvs( + kvs_client: FileSystemKeyValueStoreClient, + configuration: Configuration, +) -> None: + """Test that open() loads an existing key-value store with matching properties.""" + configuration.purge_on_start = False + + # Open the same key-value store again + reopened_client = await FileSystemStorageClient().open_key_value_store_client( + name=kvs_client.metadata.name, + configuration=configuration, + ) + + # Verify client properties + assert kvs_client.metadata.id == reopened_client.metadata.id + assert kvs_client.metadata.name == reopened_client.metadata.name + + # Verify clients (python) ids - should be the same object due to caching + assert id(kvs_client) == id(reopened_client) + + +async def test_kvs_client_purge_on_start(configuration: Configuration) -> None: + """Test that purge_on_start=True clears existing data in the key-value store.""" + configuration.purge_on_start = True + + # Create KVS and add data + kvs_client1 = await FileSystemStorageClient().open_key_value_store_client( + name='test-purge-kvs', + configuration=configuration, + ) + await kvs_client1.set_value(key='test-key', value='initial value') + + # Verify value was set + record = await kvs_client1.get_value(key='test-key') + assert record is not None + assert record.value == 'initial value' + + # Reopen + kvs_client2 = await FileSystemStorageClient().open_key_value_store_client( + name='test-purge-kvs', + 
configuration=configuration, + ) + + # Verify value was purged + record = await kvs_client2.get_value(key='test-key') + assert record is None + + +async def test_kvs_client_no_purge_on_start(configuration: Configuration) -> None: + """Test that purge_on_start=False keeps existing data in the key-value store.""" + configuration.purge_on_start = False + + # Create KVS and add data + kvs_client1 = await FileSystemStorageClient().open_key_value_store_client( + name='test-no-purge-kvs', + configuration=configuration, + ) + await kvs_client1.set_value(key='test-key', value='preserved value') + + # Reopen + kvs_client2 = await FileSystemStorageClient().open_key_value_store_client( + name='test-no-purge-kvs', + configuration=configuration, + ) + + # Verify value was preserved + record = await kvs_client2.get_value(key='test-key') + assert record is not None + assert record.value == 'preserved value' + + +async def test_open_with_id_raises_error(configuration: Configuration) -> None: + """Test that open() raises an error when an ID is provided (unsupported for file system client).""" + with pytest.raises(ValueError, match='not supported for file system storage client'): + await FileSystemStorageClient().open_key_value_store_client(id='some-id', configuration=configuration) + + +async def test_set_get_value_string(kvs_client: FileSystemKeyValueStoreClient) -> None: + """Test setting and getting a string value with correct file creation and metadata.""" + # Set a value + test_key = 'test-key' + test_value = 'Hello, world!' + await kvs_client.set_value(key=test_key, value=test_value) + + # Check if the file was created + key_path = kvs_client.path_to_kvs / test_key + key_metadata_path = kvs_client.path_to_kvs / f'{test_key}.{METADATA_FILENAME}' + assert key_path.exists() + assert key_metadata_path.exists() + + # Check file content + content = key_path.read_text(encoding='utf-8') + assert content == test_value + + # Check record metadata + with key_metadata_path.open() as f: + metadata = json.load(f) + assert metadata['key'] == test_key + assert metadata['content_type'] == 'text/plain; charset=utf-8' + assert metadata['size'] == len(test_value.encode('utf-8')) + + # Get the value + record = await kvs_client.get_value(key=test_key) + assert record is not None + assert record.key == test_key + assert record.value == test_value + assert record.content_type == 'text/plain; charset=utf-8' + assert record.size == len(test_value.encode('utf-8')) + + +async def test_set_get_value_json(kvs_client: FileSystemKeyValueStoreClient) -> None: + """Test setting and getting a JSON value with correct serialization and deserialization.""" + # Set a value + test_key = 'test-json' + test_value = {'name': 'John', 'age': 30, 'items': [1, 2, 3]} + await kvs_client.set_value(key=test_key, value=test_value) + + # Get the value + record = await kvs_client.get_value(key=test_key) + assert record is not None + assert record.key == test_key + assert record.value == test_value + assert 'application/json' in record.content_type + + +async def test_set_get_value_bytes(kvs_client: FileSystemKeyValueStoreClient) -> None: + """Test setting and getting binary data without corruption and with correct content type.""" + # Set a value + test_key = 'test-binary' + test_value = b'\x00\x01\x02\x03\x04' + await kvs_client.set_value(key=test_key, value=test_value) + + # Get the value + record = await kvs_client.get_value(key=test_key) + assert record is not None + assert record.key == test_key + assert record.value == test_value + assert 
record.content_type == 'application/octet-stream' + assert record.size == len(test_value) + + +async def test_set_value_explicit_content_type(kvs_client: FileSystemKeyValueStoreClient) -> None: + """Test that an explicitly provided content type overrides the automatically inferred one.""" + test_key = 'test-explicit-content-type' + test_value = 'Hello, world!' + explicit_content_type = 'text/html; charset=utf-8' + + await kvs_client.set_value(key=test_key, value=test_value, content_type=explicit_content_type) + + record = await kvs_client.get_value(key=test_key) + assert record is not None + assert record.content_type == explicit_content_type + + +async def test_get_nonexistent_value(kvs_client: FileSystemKeyValueStoreClient) -> None: + """Test that attempting to get a non-existent key returns None.""" + record = await kvs_client.get_value(key='nonexistent-key') + assert record is None + + +async def test_overwrite_value(kvs_client: FileSystemKeyValueStoreClient) -> None: + """Test that an existing value can be overwritten and the updated value is retrieved correctly.""" + test_key = 'test-overwrite' + + # Set initial value + initial_value = 'Initial value' + await kvs_client.set_value(key=test_key, value=initial_value) + + # Overwrite with new value + new_value = 'New value' + await kvs_client.set_value(key=test_key, value=new_value) + + # Verify the updated value + record = await kvs_client.get_value(key=test_key) + assert record is not None + assert record.value == new_value + + +async def test_delete_value(kvs_client: FileSystemKeyValueStoreClient) -> None: + """Test that deleting a value removes its files from disk and makes it irretrievable.""" + test_key = 'test-delete' + test_value = 'Delete me' + + # Set a value + await kvs_client.set_value(key=test_key, value=test_value) + + # Verify it exists + key_path = kvs_client.path_to_kvs / test_key + metadata_path = kvs_client.path_to_kvs / f'{test_key}.{METADATA_FILENAME}' + assert key_path.exists() + assert metadata_path.exists() + + # Delete the value + await kvs_client.delete_value(key=test_key) + + # Verify files were deleted + assert not key_path.exists() + assert not metadata_path.exists() + + # Verify value is no longer retrievable + record = await kvs_client.get_value(key=test_key) + assert record is None + + +async def test_delete_nonexistent_value(kvs_client: FileSystemKeyValueStoreClient) -> None: + """Test that attempting to delete a non-existent key is a no-op and doesn't raise errors.""" + # Should not raise an error + await kvs_client.delete_value(key='nonexistent-key') + + +async def test_iterate_keys_empty_store(kvs_client: FileSystemKeyValueStoreClient) -> None: + """Test that iterating over an empty store yields no keys.""" + keys = [key async for key in kvs_client.iterate_keys()] + assert len(keys) == 0 + + +async def test_iterate_keys(kvs_client: FileSystemKeyValueStoreClient) -> None: + """Test that all keys can be iterated over and are returned in sorted order.""" + # Add some values + await kvs_client.set_value(key='key1', value='value1') + await kvs_client.set_value(key='key2', value='value2') + await kvs_client.set_value(key='key3', value='value3') + + # Iterate over keys + keys = [key.key async for key in kvs_client.iterate_keys()] + assert len(keys) == 3 + assert sorted(keys) == ['key1', 'key2', 'key3'] + + +async def test_iterate_keys_with_limit(kvs_client: FileSystemKeyValueStoreClient) -> None: + """Test that the limit parameter returns only the specified number of keys.""" + # Add some values + await 
kvs_client.set_value(key='key1', value='value1') + await kvs_client.set_value(key='key2', value='value2') + await kvs_client.set_value(key='key3', value='value3') + + # Iterate with limit + keys = [key.key async for key in kvs_client.iterate_keys(limit=2)] + assert len(keys) == 2 + + +async def test_iterate_keys_with_exclusive_start_key(kvs_client: FileSystemKeyValueStoreClient) -> None: + """Test that exclusive_start_key parameter returns only keys after it alphabetically.""" + # Add some values with alphabetical keys + await kvs_client.set_value(key='a-key', value='value-a') + await kvs_client.set_value(key='b-key', value='value-b') + await kvs_client.set_value(key='c-key', value='value-c') + await kvs_client.set_value(key='d-key', value='value-d') + + # Iterate with exclusive start key + keys = [key.key async for key in kvs_client.iterate_keys(exclusive_start_key='b-key')] + assert len(keys) == 2 + assert 'c-key' in keys + assert 'd-key' in keys + assert 'a-key' not in keys + assert 'b-key' not in keys + + +async def test_drop(kvs_client: FileSystemKeyValueStoreClient) -> None: + """Test that drop removes the entire store directory from disk.""" + await kvs_client.set_value(key='test', value='test-value') + + assert kvs_client.metadata.name in FileSystemKeyValueStoreClient._cache_by_name + assert kvs_client.path_to_kvs.exists() + + # Drop the store + await kvs_client.drop() + + assert kvs_client.metadata.name not in FileSystemKeyValueStoreClient._cache_by_name + assert not kvs_client.path_to_kvs.exists() + + +async def test_metadata_updates(kvs_client: FileSystemKeyValueStoreClient) -> None: + """Test that read/write operations properly update accessed_at and modified_at timestamps.""" + # Record initial timestamps + initial_created = kvs_client.metadata.created_at + initial_accessed = kvs_client.metadata.accessed_at + initial_modified = kvs_client.metadata.modified_at + + # Wait a moment to ensure timestamps can change + await asyncio.sleep(0.01) + + # Perform an operation that updates accessed_at + await kvs_client.get_value(key='nonexistent') + + # Verify timestamps + assert kvs_client.metadata.created_at == initial_created + assert kvs_client.metadata.accessed_at > initial_accessed + assert kvs_client.metadata.modified_at == initial_modified + + accessed_after_get = kvs_client.metadata.accessed_at + + # Wait a moment to ensure timestamps can change + await asyncio.sleep(0.01) + + # Perform an operation that updates modified_at + await kvs_client.set_value(key='new-key', value='new-value') + + # Verify timestamps again + assert kvs_client.metadata.created_at == initial_created + assert kvs_client.metadata.modified_at > initial_modified + assert kvs_client.metadata.accessed_at > accessed_after_get + + +async def test_get_public_url_not_supported(kvs_client: FileSystemKeyValueStoreClient) -> None: + """Test that get_public_url raises NotImplementedError for the file system implementation.""" + with pytest.raises(NotImplementedError, match='Public URLs are not supported'): + await kvs_client.get_public_url(key='any-key') + + +async def test_concurrent_operations(kvs_client: FileSystemKeyValueStoreClient) -> None: + """Test that multiple concurrent set operations can be performed safely with correct results.""" + + # Create multiple tasks to set different values concurrently + async def set_value(key: str, value: str) -> None: + await kvs_client.set_value(key=key, value=value) + + tasks = [asyncio.create_task(set_value(f'concurrent-key-{i}', f'value-{i}')) for i in range(10)] + + # Wait 
for all tasks to complete + await asyncio.gather(*tasks) + + # Verify all values were set correctly + for i in range(10): + key = f'concurrent-key-{i}' + record = await kvs_client.get_value(key=key) + assert record is not None + assert record.value == f'value-{i}' diff --git a/tests/unit/storage_clients/_file_system/test_fs_rq_client.py b/tests/unit/storage_clients/_file_system/test_fs_rq_client.py new file mode 100644 index 0000000000..125e60f9b7 --- /dev/null +++ b/tests/unit/storage_clients/_file_system/test_fs_rq_client.py @@ -0,0 +1,500 @@ +from __future__ import annotations + +import asyncio +import json +from datetime import datetime +from typing import TYPE_CHECKING + +import pytest + +from crawlee import Request +from crawlee._consts import METADATA_FILENAME +from crawlee.configuration import Configuration +from crawlee.storage_clients import FileSystemStorageClient +from crawlee.storage_clients._file_system import FileSystemRequestQueueClient + +if TYPE_CHECKING: + from collections.abc import AsyncGenerator + from pathlib import Path + + + + +@pytest.fixture +def configuration(tmp_path: Path) -> Configuration: + return Configuration( + crawlee_storage_dir=str(tmp_path), # type: ignore[call-arg] + ) + + +@pytest.fixture +async def rq_client(configuration: Configuration) -> AsyncGenerator[FileSystemRequestQueueClient, None]: + """A fixture for a file system request queue client.""" + client = await FileSystemStorageClient().open_request_queue_client( + name='test_request_queue', + configuration=configuration, + ) + yield client + await client.drop() + + +async def test_open_creates_new_rq(configuration: Configuration) -> None: + """Test that open() creates a new request queue with proper metadata and files on disk.""" + client = await FileSystemStorageClient().open_request_queue_client( + name='new_request_queue', + configuration=configuration, + ) + + # Verify correct client type and properties + assert isinstance(client, FileSystemRequestQueueClient) + assert client.metadata.id is not None + assert client.metadata.name == 'new_request_queue' + assert client.metadata.handled_request_count == 0 + assert client.metadata.pending_request_count == 0 + assert client.metadata.total_request_count == 0 + assert isinstance(client.metadata.created_at, datetime) + assert isinstance(client.metadata.accessed_at, datetime) + assert isinstance(client.metadata.modified_at, datetime) + + # Verify files were created + assert client.path_to_rq.exists() + assert client.path_to_metadata.exists() + + # Verify metadata content + with client.path_to_metadata.open() as f: + metadata = json.load(f) + assert metadata['id'] == client.metadata.id + assert metadata['name'] == 'new_request_queue' + + +async def test_open_existing_rq( + rq_client: FileSystemRequestQueueClient, + configuration: Configuration, +) -> None: + """Test that open() loads an existing request queue correctly.""" + configuration.purge_on_start = False + + # Add a request to the original client + await rq_client.add_batch_of_requests([Request.from_url('https://example.com')]) + + # Open the same request queue again + reopened_client = await FileSystemStorageClient().open_request_queue_client( + name=rq_client.metadata.name, + configuration=configuration, + ) + + # Verify client properties + assert rq_client.metadata.id == reopened_client.metadata.id + assert rq_client.metadata.name == reopened_client.metadata.name + assert rq_client.metadata.total_request_count == 1 + assert rq_client.metadata.pending_request_count == 1 + + # Verify clients 
(python) ids - should be the same object due to caching + assert id(rq_client) == id(reopened_client) + + +async def test_rq_client_purge_on_start(configuration: Configuration) -> None: + """Test that purge_on_start=True clears existing data in the request queue.""" + configuration.purge_on_start = True + + # Create request queue and add data + rq_client1 = await FileSystemStorageClient().open_request_queue_client( + name='test-purge-rq', + configuration=configuration, + ) + await rq_client1.add_batch_of_requests([Request.from_url('https://example.com')]) + + # Verify request was added + assert rq_client1.metadata.total_request_count == 1 + + # Reopen + rq_client2 = await FileSystemStorageClient().open_request_queue_client( + name='test-purge-rq', + configuration=configuration, + ) + + # Verify data was purged + assert rq_client2.metadata.total_request_count == 0 + + +async def test_rq_client_no_purge_on_start(configuration: Configuration) -> None: + """Test that purge_on_start=False keeps existing data in the request queue.""" + configuration.purge_on_start = False + + # Create request queue and add data + rq_client1 = await FileSystemStorageClient().open_request_queue_client( + name='test-no-purge-rq', + configuration=configuration, + ) + await rq_client1.add_batch_of_requests([Request.from_url('https://example.com')]) + + # Reopen + rq_client2 = await FileSystemStorageClient().open_request_queue_client( + name='test-no-purge-rq', + configuration=configuration, + ) + + # Verify data was preserved + assert rq_client2.metadata.total_request_count == 1 + + +async def test_open_with_id_raises_error(configuration: Configuration) -> None: + """Test that open() raises an error when an ID is provided.""" + with pytest.raises(ValueError, match='not supported for file system storage client'): + await FileSystemStorageClient().open_request_queue_client(id='some-id', configuration=configuration) + + +@pytest.fixture +def rq_path(rq_client: FileSystemRequestQueueClient) -> Path: + """Return the path to the request queue directory.""" + return rq_client.path_to_rq + + +async def test_add_requests(rq_client: FileSystemRequestQueueClient) -> None: + """Test adding requests creates proper files in the filesystem.""" + # Add a batch of requests + requests = [ + Request.from_url('https://example.com/1'), + Request.from_url('https://example.com/2'), + Request.from_url('https://example.com/3'), + ] + + response = await rq_client.add_batch_of_requests(requests) + + # Verify response + assert len(response.processed_requests) == 3 + for i, processed_request in enumerate(response.processed_requests): + assert processed_request.unique_key == f'https://example.com/{i + 1}' + assert processed_request.was_already_present is False + assert processed_request.was_already_handled is False + + # Verify request files were created + request_files = list(rq_client.path_to_rq.glob('*.json')) + assert len(request_files) == 4 # 3 requests + metadata file + assert rq_client.path_to_metadata in request_files + + # Verify metadata was updated + assert rq_client.metadata.total_request_count == 3 + assert rq_client.metadata.pending_request_count == 3 + + # Verify content of the request files + for req_file in [f for f in request_files if f != rq_client.path_to_metadata]: + with req_file.open() as f: + content = json.load(f) + assert 'url' in content + assert content['url'].startswith('https://example.com/') + assert 'id' in content + assert 'handled_at' not in content # Not yet handled + + +async def 
test_add_duplicate_request(rq_client: FileSystemRequestQueueClient) -> None: + """Test adding a duplicate request.""" + request = Request.from_url('https://example.com') + + # Add the request the first time + await rq_client.add_batch_of_requests([request]) + + # Add the same request again + second_response = await rq_client.add_batch_of_requests([request]) + + # Verify response indicates it was already present + assert second_response.processed_requests[0].was_already_present is True + + # Verify only one request file exists + request_files = [f for f in rq_client.path_to_rq.glob('*.json') if f.name != METADATA_FILENAME] + assert len(request_files) == 1 + + # Verify metadata counts weren't incremented + assert rq_client.metadata.total_request_count == 1 + assert rq_client.metadata.pending_request_count == 1 + + +async def test_fetch_next_request(rq_client: FileSystemRequestQueueClient) -> None: + """Test fetching the next request from the queue.""" + # Add requests + requests = [ + Request.from_url('https://example.com/1'), + Request.from_url('https://example.com/2'), + ] + await rq_client.add_batch_of_requests(requests) + + # Fetch the first request + first_request = await rq_client.fetch_next_request() + assert first_request is not None + assert first_request.url == 'https://example.com/1' + + # Check that it's marked as in-progress + assert first_request.id in rq_client._in_progress + + # Fetch the second request + second_request = await rq_client.fetch_next_request() + assert second_request is not None + assert second_request.url == 'https://example.com/2' + + # There should be no more requests + empty_request = await rq_client.fetch_next_request() + assert empty_request is None + + +async def test_fetch_forefront_requests(rq_client: FileSystemRequestQueueClient) -> None: + """Test that forefront requests are fetched first.""" + # Add regular requests + await rq_client.add_batch_of_requests( + [ + Request.from_url('https://example.com/regular1'), + Request.from_url('https://example.com/regular2'), + ] + ) + + # Add forefront requests + await rq_client.add_batch_of_requests( + [ + Request.from_url('https://example.com/priority1'), + Request.from_url('https://example.com/priority2'), + ], + forefront=True, + ) + + # Fetch requests - they should come in priority order first + next_request1 = await rq_client.fetch_next_request() + assert next_request1 is not None + assert next_request1.url.startswith('https://example.com/priority') + + next_request2 = await rq_client.fetch_next_request() + assert next_request2 is not None + assert next_request2.url.startswith('https://example.com/priority') + + next_request3 = await rq_client.fetch_next_request() + assert next_request3 is not None + assert next_request3.url.startswith('https://example.com/regular') + + next_request4 = await rq_client.fetch_next_request() + assert next_request4 is not None + assert next_request4.url.startswith('https://example.com/regular') + + +async def test_mark_request_as_handled(rq_client: FileSystemRequestQueueClient) -> None: + """Test marking a request as handled.""" + # Add and fetch a request + await rq_client.add_batch_of_requests([Request.from_url('https://example.com')]) + request = await rq_client.fetch_next_request() + assert request is not None + + # Mark it as handled + result = await rq_client.mark_request_as_handled(request) + assert result is not None + assert result.was_already_handled is True + + # Verify it's no longer in-progress + assert request.id not in rq_client._in_progress + + # Verify 
metadata was updated + assert rq_client.metadata.handled_request_count == 1 + assert rq_client.metadata.pending_request_count == 0 + + # Verify the file was updated with handled_at timestamp + request_files = [f for f in rq_client.path_to_rq.glob('*.json') if f.name != METADATA_FILENAME] + assert len(request_files) == 1 + + with request_files[0].open() as f: + content = json.load(f) + assert 'handled_at' in content + assert content['handled_at'] is not None + + +async def test_reclaim_request(rq_client: FileSystemRequestQueueClient) -> None: + """Test reclaiming a request that failed processing.""" + # Add and fetch a request + await rq_client.add_batch_of_requests([Request.from_url('https://example.com')]) + request = await rq_client.fetch_next_request() + assert request is not None + + # Reclaim the request + result = await rq_client.reclaim_request(request) + assert result is not None + assert result.was_already_handled is False + + # Verify it's no longer in-progress + assert request.id not in rq_client._in_progress + + # Should be able to fetch it again + reclaimed_request = await rq_client.fetch_next_request() + assert reclaimed_request is not None + assert reclaimed_request.id == request.id + + +async def test_reclaim_request_with_forefront(rq_client: FileSystemRequestQueueClient) -> None: + """Test reclaiming a request with forefront priority.""" + # Add requests + await rq_client.add_batch_of_requests( + [ + Request.from_url('https://example.com/first'), + Request.from_url('https://example.com/second'), + ] + ) + + # Fetch the first request + first_request = await rq_client.fetch_next_request() + assert first_request is not None + assert first_request.url == 'https://example.com/first' + + # Reclaim it with forefront priority + await rq_client.reclaim_request(first_request, forefront=True) + + # Verify it's in the forefront set + assert first_request.id in rq_client._forefront_requests + + # It should be returned before the second request + reclaimed_request = await rq_client.fetch_next_request() + assert reclaimed_request is not None + assert reclaimed_request.url == 'https://example.com/first' + + +async def test_is_empty(rq_client: FileSystemRequestQueueClient) -> None: + """Test checking if a queue is empty.""" + # Queue should start empty + assert await rq_client.is_empty() is True + + # Add a request + await rq_client.add_batch_of_requests([Request.from_url('https://example.com')]) + assert await rq_client.is_empty() is False + + # Fetch and handle the request + request = await rq_client.fetch_next_request() + assert request is not None + await rq_client.mark_request_as_handled(request) + + # Queue should be empty again + assert await rq_client.is_empty() is True + + +async def test_get_request(rq_client: FileSystemRequestQueueClient) -> None: + """Test getting a request by ID.""" + # Add a request + response = await rq_client.add_batch_of_requests([Request.from_url('https://example.com')]) + request_id = response.processed_requests[0].id + + # Get the request by ID + request = await rq_client.get_request(request_id) + assert request is not None + assert request.id == request_id + assert request.url == 'https://example.com' + + # Try to get a non-existent request + not_found = await rq_client.get_request('non-existent-id') + assert not_found is None + + +async def test_drop(configuration: Configuration) -> None: + """Test dropping the queue removes files from the filesystem.""" + client = await FileSystemStorageClient().open_request_queue_client( + name='drop_test', + 
configuration=configuration, + ) + + # Add requests to create files + await client.add_batch_of_requests( + [ + Request.from_url('https://example.com/1'), + Request.from_url('https://example.com/2'), + ] + ) + + # Verify the directory exists + rq_path = client.path_to_rq + assert rq_path.exists() + + # Drop the client + await client.drop() + + # Verify the directory was removed + assert not rq_path.exists() + + # Verify the client was removed from the cache + assert client.metadata.name not in FileSystemRequestQueueClient._cache_by_name + + +async def test_file_persistence(configuration: Configuration) -> None: + """Test that requests are persisted to files and can be recovered after a 'restart'.""" + # Explicitly set purge_on_start to False to ensure files aren't deleted + configuration.purge_on_start = False + + # Create a client and add requests + client1 = await FileSystemStorageClient().open_request_queue_client( + name='persistence_test', + configuration=configuration, + ) + + await client1.add_batch_of_requests( + [ + Request.from_url('https://example.com/1'), + Request.from_url('https://example.com/2'), + ] + ) + + # Fetch and handle one request + request = await client1.fetch_next_request() + assert request is not None + await client1.mark_request_as_handled(request) + + # Get the storage directory path before clearing the cache + storage_path = client1.path_to_rq + assert storage_path.exists(), 'Request queue directory should exist' + + # Verify files exist + request_files = list(storage_path.glob('*.json')) + assert len(request_files) > 0, 'Request files should exist' + + # Clear cache to simulate process restart + FileSystemRequestQueueClient._cache_by_name.clear() + + # Create a new client with same name (which will load from files) + client2 = await FileSystemStorageClient().open_request_queue_client( + name='persistence_test', + configuration=configuration, + ) + + # Verify state was recovered + assert client2.metadata.total_request_count == 2 + assert client2.metadata.handled_request_count == 1 + assert client2.metadata.pending_request_count == 1 + + # Should be able to fetch the remaining request + remaining_request = await client2.fetch_next_request() + assert remaining_request is not None + assert remaining_request.url == 'https://example.com/2' + + # Clean up + await client2.drop() + + +async def test_metadata_updates(rq_client: FileSystemRequestQueueClient) -> None: + """Test that metadata timestamps are updated correctly after operations.""" + # Record initial timestamps + initial_created = rq_client.metadata.created_at + initial_accessed = rq_client.metadata.accessed_at + initial_modified = rq_client.metadata.modified_at + + # Wait a moment to ensure timestamps can change + await asyncio.sleep(0.01) + + # Perform an operation that updates accessed_at + await rq_client.is_empty() + + # Verify timestamps + assert rq_client.metadata.created_at == initial_created + assert rq_client.metadata.accessed_at > initial_accessed + assert rq_client.metadata.modified_at == initial_modified + + accessed_after_get = rq_client.metadata.accessed_at + + # Wait a moment to ensure timestamps can change + await asyncio.sleep(0.01) + + # Perform an operation that updates modified_at + await rq_client.add_batch_of_requests([Request.from_url('https://example.com')]) + + # Verify timestamps again + assert rq_client.metadata.created_at == initial_created + assert rq_client.metadata.modified_at > initial_modified + assert rq_client.metadata.accessed_at > accessed_after_get diff --git 
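Similarly, the request-queue workflow asserted by the tests above (add a batch, fetch, handle, check emptiness) can be summarized by the following rough sketch. Only calls that appear in those tests are used; the queue name and storage path are illustrative assumptions, not part of the change itself.

import asyncio

from crawlee import Request
from crawlee.configuration import Configuration
from crawlee.storage_clients import FileSystemStorageClient


async def main() -> None:
    # Illustrative storage location; the tests above rely on a pytest tmp_path.
    configuration = Configuration(
        crawlee_storage_dir='./example-storage',  # type: ignore[call-arg]
    )

    rq_client = await FileSystemStorageClient().open_request_queue_client(
        name='example-queue',
        configuration=configuration,
    )

    # Enqueue a small batch of requests.
    await rq_client.add_batch_of_requests(
        [
            Request.from_url('https://example.com/1'),
            Request.from_url('https://example.com/2'),
        ]
    )

    # Drain the queue: fetch each request and mark it as handled. A failed
    # request could instead be re-queued via reclaim_request(request, forefront=True).
    while not await rq_client.is_empty():
        request = await rq_client.fetch_next_request()
        if request is None:
            break
        await rq_client.mark_request_as_handled(request)

    await rq_client.drop()


asyncio.run(main())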
a/tests/unit/storage_clients/_memory/test_creation_management.py b/tests/unit/storage_clients/_memory/test_creation_management.py deleted file mode 100644 index 88a5e9e283..0000000000 --- a/tests/unit/storage_clients/_memory/test_creation_management.py +++ /dev/null @@ -1,59 +0,0 @@ -from __future__ import annotations - -import json -from pathlib import Path -from unittest.mock import AsyncMock, patch - -import pytest - -from crawlee._consts import METADATA_FILENAME -from crawlee.storage_clients._memory._creation_management import persist_metadata_if_enabled - - -async def test_persist_metadata_skips_when_disabled(tmp_path: Path) -> None: - await persist_metadata_if_enabled(data={'key': 'value'}, entity_directory=str(tmp_path), write_metadata=False) - assert not list(tmp_path.iterdir()) # The directory should be empty since write_metadata is False - - -async def test_persist_metadata_creates_files_and_directories_when_enabled(tmp_path: Path) -> None: - data = {'key': 'value'} - entity_directory = Path(tmp_path, 'new_dir') - await persist_metadata_if_enabled(data=data, entity_directory=str(entity_directory), write_metadata=True) - assert entity_directory.exists() is True # Check if directory was created - assert (entity_directory / METADATA_FILENAME).is_file() # Check if file was created - - -async def test_persist_metadata_correctly_writes_data(tmp_path: Path) -> None: - data = {'key': 'value'} - entity_directory = Path(tmp_path, 'data_dir') - await persist_metadata_if_enabled(data=data, entity_directory=str(entity_directory), write_metadata=True) - metadata_path = entity_directory / METADATA_FILENAME - with open(metadata_path) as f: # noqa: ASYNC230 - content = f.read() - assert json.loads(content) == data # Check if correct data was written - - -async def test_persist_metadata_rewrites_data_with_error(tmp_path: Path) -> None: - init_data = {'key': 'very_long_value'} - update_data = {'key': 'short_value'} - error_data = {'key': 'error'} - - entity_directory = Path(tmp_path, 'data_dir') - metadata_path = entity_directory / METADATA_FILENAME - - # write metadata with init_data - await persist_metadata_if_enabled(data=init_data, entity_directory=str(entity_directory), write_metadata=True) - - # rewrite metadata with new_data - await persist_metadata_if_enabled(data=update_data, entity_directory=str(entity_directory), write_metadata=True) - with open(metadata_path) as f: # noqa: ASYNC230 - content = f.read() - assert json.loads(content) == update_data # Check if correct data was rewritten - - # raise interrupt between opening a file and writing - module_for_patch = 'crawlee.storage_clients._memory._creation_management.json_dumps' - with patch(module_for_patch, AsyncMock(side_effect=KeyboardInterrupt())), pytest.raises(KeyboardInterrupt): - await persist_metadata_if_enabled(data=error_data, entity_directory=str(entity_directory), write_metadata=True) - with open(metadata_path) as f: # noqa: ASYNC230 - content = f.read() - assert content == '' # The file is empty after an error diff --git a/tests/unit/storage_clients/_memory/test_dataset_client.py b/tests/unit/storage_clients/_memory/test_dataset_client.py deleted file mode 100644 index 472d11a8b3..0000000000 --- a/tests/unit/storage_clients/_memory/test_dataset_client.py +++ /dev/null @@ -1,148 +0,0 @@ -from __future__ import annotations - -import asyncio -from pathlib import Path -from typing import TYPE_CHECKING - -import pytest - -if TYPE_CHECKING: - from crawlee.storage_clients import MemoryStorageClient - from 
crawlee.storage_clients._memory import DatasetClient - - -@pytest.fixture -async def dataset_client(memory_storage_client: MemoryStorageClient) -> DatasetClient: - datasets_client = memory_storage_client.datasets() - dataset_info = await datasets_client.get_or_create(name='test') - return memory_storage_client.dataset(dataset_info.id) - - -async def test_nonexistent(memory_storage_client: MemoryStorageClient) -> None: - dataset_client = memory_storage_client.dataset(id='nonexistent-id') - assert await dataset_client.get() is None - with pytest.raises(ValueError, match='Dataset with id "nonexistent-id" does not exist.'): - await dataset_client.update(name='test-update') - - with pytest.raises(ValueError, match='Dataset with id "nonexistent-id" does not exist.'): - await dataset_client.list_items() - - with pytest.raises(ValueError, match='Dataset with id "nonexistent-id" does not exist.'): - await dataset_client.push_items([{'abc': 123}]) - await dataset_client.delete() - - -async def test_not_implemented(dataset_client: DatasetClient) -> None: - with pytest.raises(NotImplementedError, match='This method is not supported in memory storage.'): - await dataset_client.stream_items() - with pytest.raises(NotImplementedError, match='This method is not supported in memory storage.'): - await dataset_client.get_items_as_bytes() - - -async def test_get(dataset_client: DatasetClient) -> None: - await asyncio.sleep(0.1) - info = await dataset_client.get() - assert info is not None - assert info.id == dataset_client.id - assert info.accessed_at != info.created_at - - -async def test_update(dataset_client: DatasetClient) -> None: - new_dataset_name = 'test-update' - await dataset_client.push_items({'abc': 123}) - - old_dataset_info = await dataset_client.get() - assert old_dataset_info is not None - old_dataset_directory = Path(dataset_client._memory_storage_client.datasets_directory, old_dataset_info.name or '') - new_dataset_directory = Path(dataset_client._memory_storage_client.datasets_directory, new_dataset_name) - assert (old_dataset_directory / '000000001.json').exists() is True - assert (new_dataset_directory / '000000001.json').exists() is False - - await asyncio.sleep(0.1) - updated_dataset_info = await dataset_client.update(name=new_dataset_name) - assert (old_dataset_directory / '000000001.json').exists() is False - assert (new_dataset_directory / '000000001.json').exists() is True - # Only modified_at and accessed_at should be different - assert old_dataset_info.created_at == updated_dataset_info.created_at - assert old_dataset_info.modified_at != updated_dataset_info.modified_at - assert old_dataset_info.accessed_at != updated_dataset_info.accessed_at - - # Should fail with the same name - with pytest.raises(ValueError, match='Dataset with name "test-update" already exists.'): - await dataset_client.update(name=new_dataset_name) - - -async def test_delete(dataset_client: DatasetClient) -> None: - await dataset_client.push_items({'abc': 123}) - dataset_info = await dataset_client.get() - assert dataset_info is not None - dataset_directory = Path(dataset_client._memory_storage_client.datasets_directory, dataset_info.name or '') - assert (dataset_directory / '000000001.json').exists() is True - await dataset_client.delete() - assert (dataset_directory / '000000001.json').exists() is False - # Does not crash when called again - await dataset_client.delete() - - -async def test_push_items(dataset_client: DatasetClient) -> None: - await dataset_client.push_items('{"test": "JSON from a 
string"}') - await dataset_client.push_items({'abc': {'def': {'ghi': '123'}}}) - await dataset_client.push_items(['{"test-json-parse": "JSON from a string"}' for _ in range(10)]) - await dataset_client.push_items([{'test-dict': i} for i in range(10)]) - - list_page = await dataset_client.list_items() - assert list_page.items[0]['test'] == 'JSON from a string' - assert list_page.items[1]['abc']['def']['ghi'] == '123' - assert list_page.items[11]['test-json-parse'] == 'JSON from a string' - assert list_page.items[21]['test-dict'] == 9 - assert list_page.count == 22 - - -async def test_list_items(dataset_client: DatasetClient) -> None: - item_count = 100 - used_offset = 10 - used_limit = 50 - await dataset_client.push_items([{'id': i} for i in range(item_count)]) - # Test without any parameters - list_default = await dataset_client.list_items() - assert list_default.count == item_count - assert list_default.offset == 0 - assert list_default.items[0]['id'] == 0 - assert list_default.desc is False - # Test offset - list_offset_10 = await dataset_client.list_items(offset=used_offset) - assert list_offset_10.count == item_count - used_offset - assert list_offset_10.offset == used_offset - assert list_offset_10.total == item_count - assert list_offset_10.items[0]['id'] == used_offset - # Test limit - list_limit_50 = await dataset_client.list_items(limit=used_limit) - assert list_limit_50.count == used_limit - assert list_limit_50.limit == used_limit - assert list_limit_50.total == item_count - # Test desc - list_desc_true = await dataset_client.list_items(desc=True) - assert list_desc_true.items[0]['id'] == 99 - assert list_desc_true.desc is True - - -async def test_iterate_items(dataset_client: DatasetClient) -> None: - item_count = 100 - await dataset_client.push_items([{'id': i} for i in range(item_count)]) - actual_items = [] - async for item in dataset_client.iterate_items(): - assert 'id' in item - actual_items.append(item) - assert len(actual_items) == item_count - assert actual_items[0]['id'] == 0 - assert actual_items[99]['id'] == 99 - - -async def test_reuse_dataset(dataset_client: DatasetClient, memory_storage_client: MemoryStorageClient) -> None: - item_count = 10 - await dataset_client.push_items([{'id': i} for i in range(item_count)]) - - memory_storage_client.datasets_handled = [] # purge datasets loaded to test create_dataset_from_directory - datasets_client = memory_storage_client.datasets() - dataset_info = await datasets_client.get_or_create(name='test') - assert dataset_info.item_count == item_count diff --git a/tests/unit/storage_clients/_memory/test_dataset_collection_client.py b/tests/unit/storage_clients/_memory/test_dataset_collection_client.py deleted file mode 100644 index d71b7e8f68..0000000000 --- a/tests/unit/storage_clients/_memory/test_dataset_collection_client.py +++ /dev/null @@ -1,45 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING - -import pytest - -if TYPE_CHECKING: - from crawlee.storage_clients import MemoryStorageClient - from crawlee.storage_clients._memory import DatasetCollectionClient - - -@pytest.fixture -def datasets_client(memory_storage_client: MemoryStorageClient) -> DatasetCollectionClient: - return memory_storage_client.datasets() - - -async def test_get_or_create(datasets_client: DatasetCollectionClient) -> None: - dataset_name = 'test' - # A new dataset gets created - dataset_info = await datasets_client.get_or_create(name=dataset_name) - assert dataset_info.name == dataset_name - - # Another get_or_create call 
returns the same dataset - dataset_info_existing = await datasets_client.get_or_create(name=dataset_name) - assert dataset_info.id == dataset_info_existing.id - assert dataset_info.name == dataset_info_existing.name - assert dataset_info.created_at == dataset_info_existing.created_at - - -async def test_list(datasets_client: DatasetCollectionClient) -> None: - dataset_list_1 = await datasets_client.list() - assert dataset_list_1.count == 0 - - dataset_info = await datasets_client.get_or_create(name='dataset') - dataset_list_2 = await datasets_client.list() - - assert dataset_list_2.count == 1 - assert dataset_list_2.items[0].name == dataset_info.name - - # Test sorting behavior - newer_dataset_info = await datasets_client.get_or_create(name='newer-dataset') - dataset_list_sorting = await datasets_client.list() - assert dataset_list_sorting.count == 2 - assert dataset_list_sorting.items[0].name == dataset_info.name - assert dataset_list_sorting.items[1].name == newer_dataset_info.name diff --git a/tests/unit/storage_clients/_memory/test_key_value_store_client.py b/tests/unit/storage_clients/_memory/test_key_value_store_client.py deleted file mode 100644 index 26d1f8f974..0000000000 --- a/tests/unit/storage_clients/_memory/test_key_value_store_client.py +++ /dev/null @@ -1,443 +0,0 @@ -from __future__ import annotations - -import asyncio -import base64 -import json -from datetime import datetime, timezone -from pathlib import Path -from typing import TYPE_CHECKING - -import pytest - -from crawlee._consts import METADATA_FILENAME -from crawlee._utils.crypto import crypto_random_object_id -from crawlee._utils.data_processing import maybe_parse_body -from crawlee._utils.file import json_dumps -from crawlee.storage_clients.models import KeyValueStoreMetadata, KeyValueStoreRecordMetadata - -if TYPE_CHECKING: - from crawlee.storage_clients import MemoryStorageClient - from crawlee.storage_clients._memory import KeyValueStoreClient - -TINY_PNG = base64.b64decode( - s='iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVQYV2NgYAAAAAMAAWgmWQ0AAAAASUVORK5CYII=', -) -TINY_BYTES = b'\x12\x34\x56\x78\x90\xab\xcd\xef' -TINY_DATA = {'a': 'b'} -TINY_TEXT = 'abcd' - - -@pytest.fixture -async def key_value_store_client(memory_storage_client: MemoryStorageClient) -> KeyValueStoreClient: - key_value_stores_client = memory_storage_client.key_value_stores() - kvs_info = await key_value_stores_client.get_or_create(name='test') - return memory_storage_client.key_value_store(kvs_info.id) - - -async def test_nonexistent(memory_storage_client: MemoryStorageClient) -> None: - kvs_client = memory_storage_client.key_value_store(id='nonexistent-id') - assert await kvs_client.get() is None - - with pytest.raises(ValueError, match='Key-value store with id "nonexistent-id" does not exist.'): - await kvs_client.update(name='test-update') - - with pytest.raises(ValueError, match='Key-value store with id "nonexistent-id" does not exist.'): - await kvs_client.list_keys() - - with pytest.raises(ValueError, match='Key-value store with id "nonexistent-id" does not exist.'): - await kvs_client.set_record('test', {'abc': 123}) - - with pytest.raises(ValueError, match='Key-value store with id "nonexistent-id" does not exist.'): - await kvs_client.get_record('test') - - with pytest.raises(ValueError, match='Key-value store with id "nonexistent-id" does not exist.'): - await kvs_client.get_record_as_bytes('test') - - with pytest.raises(ValueError, match='Key-value store with id "nonexistent-id" does not exist.'): - await 
kvs_client.delete_record('test') - - await kvs_client.delete() - - -async def test_not_implemented(key_value_store_client: KeyValueStoreClient) -> None: - with pytest.raises(NotImplementedError, match='This method is not supported in memory storage.'): - await key_value_store_client.stream_record('test') - - -async def test_get(key_value_store_client: KeyValueStoreClient) -> None: - await asyncio.sleep(0.1) - info = await key_value_store_client.get() - assert info is not None - assert info.id == key_value_store_client.id - assert info.accessed_at != info.created_at - - -async def test_update(key_value_store_client: KeyValueStoreClient) -> None: - new_kvs_name = 'test-update' - await key_value_store_client.set_record('test', {'abc': 123}) - old_kvs_info = await key_value_store_client.get() - assert old_kvs_info is not None - old_kvs_directory = Path( - key_value_store_client._memory_storage_client.key_value_stores_directory, old_kvs_info.name or '' - ) - new_kvs_directory = Path(key_value_store_client._memory_storage_client.key_value_stores_directory, new_kvs_name) - assert (old_kvs_directory / 'test.json').exists() is True - assert (new_kvs_directory / 'test.json').exists() is False - - await asyncio.sleep(0.1) - updated_kvs_info = await key_value_store_client.update(name=new_kvs_name) - assert (old_kvs_directory / 'test.json').exists() is False - assert (new_kvs_directory / 'test.json').exists() is True - # Only modified_at and accessed_at should be different - assert old_kvs_info.created_at == updated_kvs_info.created_at - assert old_kvs_info.modified_at != updated_kvs_info.modified_at - assert old_kvs_info.accessed_at != updated_kvs_info.accessed_at - - # Should fail with the same name - with pytest.raises(ValueError, match='Key-value store with name "test-update" already exists.'): - await key_value_store_client.update(name=new_kvs_name) - - -async def test_delete(key_value_store_client: KeyValueStoreClient) -> None: - await key_value_store_client.set_record('test', {'abc': 123}) - kvs_info = await key_value_store_client.get() - assert kvs_info is not None - kvs_directory = Path(key_value_store_client._memory_storage_client.key_value_stores_directory, kvs_info.name or '') - assert (kvs_directory / 'test.json').exists() is True - await key_value_store_client.delete() - assert (kvs_directory / 'test.json').exists() is False - # Does not crash when called again - await key_value_store_client.delete() - - -async def test_list_keys_empty(key_value_store_client: KeyValueStoreClient) -> None: - keys = await key_value_store_client.list_keys() - assert len(keys.items) == 0 - assert keys.count == 0 - assert keys.is_truncated is False - - -async def test_list_keys(key_value_store_client: KeyValueStoreClient) -> None: - record_count = 4 - used_limit = 2 - used_exclusive_start_key = 'a' - await key_value_store_client.set_record('b', 'test') - await key_value_store_client.set_record('a', 'test') - await key_value_store_client.set_record('d', 'test') - await key_value_store_client.set_record('c', 'test') - - # Default settings - keys = await key_value_store_client.list_keys() - assert keys.items[0].key == 'a' - assert keys.items[3].key == 'd' - assert keys.count == record_count - assert keys.is_truncated is False - # Test limit - keys_limit_2 = await key_value_store_client.list_keys(limit=used_limit) - assert keys_limit_2.count == record_count - assert keys_limit_2.limit == used_limit - assert keys_limit_2.items[1].key == 'b' - # Test exclusive start key - keys_exclusive_start = await 
key_value_store_client.list_keys(exclusive_start_key=used_exclusive_start_key, limit=2) - assert keys_exclusive_start.exclusive_start_key == used_exclusive_start_key - assert keys_exclusive_start.is_truncated is True - assert keys_exclusive_start.next_exclusive_start_key == 'c' - assert keys_exclusive_start.items[0].key == 'b' - assert keys_exclusive_start.items[-1].key == keys_exclusive_start.next_exclusive_start_key - - -async def test_get_and_set_record(tmp_path: Path, key_value_store_client: KeyValueStoreClient) -> None: - # Test setting dict record - dict_record_key = 'test-dict' - await key_value_store_client.set_record(dict_record_key, {'test': 123}) - dict_record_info = await key_value_store_client.get_record(dict_record_key) - assert dict_record_info is not None - assert 'application/json' in str(dict_record_info.content_type) - assert dict_record_info.value['test'] == 123 - - # Test setting str record - str_record_key = 'test-str' - await key_value_store_client.set_record(str_record_key, 'test') - str_record_info = await key_value_store_client.get_record(str_record_key) - assert str_record_info is not None - assert 'text/plain' in str(str_record_info.content_type) - assert str_record_info.value == 'test' - - # Test setting explicit json record but use str as value, i.e. json dumps is skipped - explicit_json_key = 'test-json' - await key_value_store_client.set_record(explicit_json_key, '{"test": "explicit string"}', 'application/json') - bytes_record_info = await key_value_store_client.get_record(explicit_json_key) - assert bytes_record_info is not None - assert 'application/json' in str(bytes_record_info.content_type) - assert bytes_record_info.value['test'] == 'explicit string' - - # Test using bytes - bytes_key = 'test-json' - bytes_value = b'testing bytes set_record' - await key_value_store_client.set_record(bytes_key, bytes_value, 'unknown') - bytes_record_info = await key_value_store_client.get_record(bytes_key) - assert bytes_record_info is not None - assert 'unknown' in str(bytes_record_info.content_type) - assert bytes_record_info.value == bytes_value - assert bytes_record_info.value.decode('utf-8') == bytes_value.decode('utf-8') - - # Test using file descriptor - with open(tmp_path / 'test.json', 'w+', encoding='utf-8') as f: # noqa: ASYNC230 - f.write('Test') - with pytest.raises(NotImplementedError, match='File-like values are not supported in local memory storage'): - await key_value_store_client.set_record('file', f) - - -async def test_get_record_as_bytes(key_value_store_client: KeyValueStoreClient) -> None: - record_key = 'test' - record_value = 'testing' - await key_value_store_client.set_record(record_key, record_value) - record_info = await key_value_store_client.get_record_as_bytes(record_key) - assert record_info is not None - assert record_info.value == record_value.encode('utf-8') - - -async def test_delete_record(key_value_store_client: KeyValueStoreClient) -> None: - record_key = 'test' - await key_value_store_client.set_record(record_key, 'test') - await key_value_store_client.delete_record(record_key) - # Does not crash when called again - await key_value_store_client.delete_record(record_key) - - -@pytest.mark.parametrize( - ('input_data', 'expected_output'), - [ - ( - {'key': 'image', 'value': TINY_PNG, 'contentType': None}, - {'filename': 'image', 'key': 'image', 'contentType': 'application/octet-stream'}, - ), - ( - {'key': 'image', 'value': TINY_PNG, 'contentType': 'image/png'}, - {'filename': 'image.png', 'key': 'image', 'contentType': 
'image/png'}, - ), - ( - {'key': 'image.png', 'value': TINY_PNG, 'contentType': None}, - {'filename': 'image.png', 'key': 'image.png', 'contentType': 'application/octet-stream'}, - ), - ( - {'key': 'image.png', 'value': TINY_PNG, 'contentType': 'image/png'}, - {'filename': 'image.png', 'key': 'image.png', 'contentType': 'image/png'}, - ), - ( - {'key': 'data', 'value': TINY_DATA, 'contentType': None}, - {'filename': 'data.json', 'key': 'data', 'contentType': 'application/json'}, - ), - ( - {'key': 'data', 'value': TINY_DATA, 'contentType': 'application/json'}, - {'filename': 'data.json', 'key': 'data', 'contentType': 'application/json'}, - ), - ( - {'key': 'data.json', 'value': TINY_DATA, 'contentType': None}, - {'filename': 'data.json', 'key': 'data.json', 'contentType': 'application/json'}, - ), - ( - {'key': 'data.json', 'value': TINY_DATA, 'contentType': 'application/json'}, - {'filename': 'data.json', 'key': 'data.json', 'contentType': 'application/json'}, - ), - ( - {'key': 'text', 'value': TINY_TEXT, 'contentType': None}, - {'filename': 'text.txt', 'key': 'text', 'contentType': 'text/plain'}, - ), - ( - {'key': 'text', 'value': TINY_TEXT, 'contentType': 'text/plain'}, - {'filename': 'text.txt', 'key': 'text', 'contentType': 'text/plain'}, - ), - ( - {'key': 'text.txt', 'value': TINY_TEXT, 'contentType': None}, - {'filename': 'text.txt', 'key': 'text.txt', 'contentType': 'text/plain'}, - ), - ( - {'key': 'text.txt', 'value': TINY_TEXT, 'contentType': 'text/plain'}, - {'filename': 'text.txt', 'key': 'text.txt', 'contentType': 'text/plain'}, - ), - ], -) -async def test_writes_correct_metadata( - memory_storage_client: MemoryStorageClient, - input_data: dict, - expected_output: dict, -) -> None: - key_value_store_name = crypto_random_object_id() - - # Get KVS client - kvs_info = await memory_storage_client.key_value_stores().get_or_create(name=key_value_store_name) - kvs_client = memory_storage_client.key_value_store(kvs_info.id) - - # Write the test input item to the store - await kvs_client.set_record( - key=input_data['key'], - value=input_data['value'], - content_type=input_data['contentType'], - ) - - # Check that everything was written correctly, both the data and metadata - storage_path = Path(memory_storage_client.key_value_stores_directory, key_value_store_name) - item_path = Path(storage_path, expected_output['filename']) - item_metadata_path = storage_path / f'{expected_output["filename"]}.__metadata__.json' - - assert item_path.exists() - assert item_metadata_path.exists() - - # Test the actual value of the item - with open(item_path, 'rb') as item_file: # noqa: ASYNC230 - actual_value = maybe_parse_body(item_file.read(), expected_output['contentType']) - assert actual_value == input_data['value'] - - # Test the actual metadata of the item - with open(item_metadata_path, encoding='utf-8') as metadata_file: # noqa: ASYNC230 - json_content = json.load(metadata_file) - metadata = KeyValueStoreRecordMetadata(**json_content) - assert metadata.key == expected_output['key'] - assert expected_output['contentType'] in metadata.content_type - - -@pytest.mark.parametrize( - ('input_data', 'expected_output'), - [ - ( - {'filename': 'image', 'value': TINY_PNG, 'metadata': None}, - {'key': 'image', 'filename': 'image', 'contentType': 'application/octet-stream'}, - ), - ( - {'filename': 'image.png', 'value': TINY_PNG, 'metadata': None}, - {'key': 'image', 'filename': 'image.png', 'contentType': 'image/png'}, - ), - ( - { - 'filename': 'image', - 'value': TINY_PNG, - 'metadata': {'key': 
'image', 'contentType': 'application/octet-stream'}, - }, - {'key': 'image', 'contentType': 'application/octet-stream'}, - ), - ( - {'filename': 'image', 'value': TINY_PNG, 'metadata': {'key': 'image', 'contentType': 'image/png'}}, - {'key': 'image', 'filename': 'image', 'contentType': 'image/png'}, - ), - ( - { - 'filename': 'image.png', - 'value': TINY_PNG, - 'metadata': {'key': 'image.png', 'contentType': 'application/octet-stream'}, - }, - {'key': 'image.png', 'contentType': 'application/octet-stream'}, - ), - ( - {'filename': 'image.png', 'value': TINY_PNG, 'metadata': {'key': 'image.png', 'contentType': 'image/png'}}, - {'key': 'image.png', 'contentType': 'image/png'}, - ), - ( - {'filename': 'image.png', 'value': TINY_PNG, 'metadata': {'key': 'image', 'contentType': 'image/png'}}, - {'key': 'image', 'contentType': 'image/png'}, - ), - ( - {'filename': 'input', 'value': TINY_BYTES, 'metadata': None}, - {'key': 'input', 'contentType': 'application/octet-stream'}, - ), - ( - {'filename': 'input.json', 'value': TINY_DATA, 'metadata': None}, - {'key': 'input', 'contentType': 'application/json'}, - ), - ( - {'filename': 'input.txt', 'value': TINY_TEXT, 'metadata': None}, - {'key': 'input', 'contentType': 'text/plain'}, - ), - ( - {'filename': 'input.bin', 'value': TINY_BYTES, 'metadata': None}, - {'key': 'input', 'contentType': 'application/octet-stream'}, - ), - ( - { - 'filename': 'input', - 'value': TINY_BYTES, - 'metadata': {'key': 'input', 'contentType': 'application/octet-stream'}, - }, - {'key': 'input', 'contentType': 'application/octet-stream'}, - ), - ( - { - 'filename': 'input.json', - 'value': TINY_DATA, - 'metadata': {'key': 'input', 'contentType': 'application/json'}, - }, - {'key': 'input', 'contentType': 'application/json'}, - ), - ( - {'filename': 'input.txt', 'value': TINY_TEXT, 'metadata': {'key': 'input', 'contentType': 'text/plain'}}, - {'key': 'input', 'contentType': 'text/plain'}, - ), - ( - { - 'filename': 'input.bin', - 'value': TINY_BYTES, - 'metadata': {'key': 'input', 'contentType': 'application/octet-stream'}, - }, - {'key': 'input', 'contentType': 'application/octet-stream'}, - ), - ], -) -async def test_reads_correct_metadata( - memory_storage_client: MemoryStorageClient, - input_data: dict, - expected_output: dict, -) -> None: - key_value_store_name = crypto_random_object_id() - - # Ensure the directory for the store exists - storage_path = Path(memory_storage_client.key_value_stores_directory, key_value_store_name) - storage_path.mkdir(exist_ok=True, parents=True) - - store_metadata = KeyValueStoreMetadata( - id=crypto_random_object_id(), - name='', - accessed_at=datetime.now(timezone.utc), - created_at=datetime.now(timezone.utc), - modified_at=datetime.now(timezone.utc), - user_id='1', - ) - - # Write the store metadata to disk - storage_metadata_path = storage_path / METADATA_FILENAME - with open(storage_metadata_path, mode='wb') as f: # noqa: ASYNC230 - f.write(store_metadata.model_dump_json().encode('utf-8')) - - # Write the test input item to the disk - item_path = storage_path / input_data['filename'] - with open(item_path, 'wb') as item_file: # noqa: ASYNC230 - if isinstance(input_data['value'], bytes): - item_file.write(input_data['value']) - elif isinstance(input_data['value'], str): - item_file.write(input_data['value'].encode('utf-8')) - else: - s = await json_dumps(input_data['value']) - item_file.write(s.encode('utf-8')) - - # Optionally write the metadata to disk if there is some - if input_data['metadata'] is not None: - 
storage_metadata_path = storage_path / f'{input_data["filename"]}.__metadata__.json' - with open(storage_metadata_path, 'w', encoding='utf-8') as metadata_file: # noqa: ASYNC230 - s = await json_dumps( - { - 'key': input_data['metadata']['key'], - 'contentType': input_data['metadata']['contentType'], - } - ) - metadata_file.write(s) - - # Create the key-value store client to load the items from disk - store_details = await memory_storage_client.key_value_stores().get_or_create(name=key_value_store_name) - key_value_store_client = memory_storage_client.key_value_store(store_details.id) - - # Read the item from the store and check if it is as expected - actual_record = await key_value_store_client.get_record(expected_output['key']) - assert actual_record is not None - - assert actual_record.key == expected_output['key'] - assert actual_record.content_type == expected_output['contentType'] - assert actual_record.value == input_data['value'] diff --git a/tests/unit/storage_clients/_memory/test_key_value_store_collection_client.py b/tests/unit/storage_clients/_memory/test_key_value_store_collection_client.py deleted file mode 100644 index 41b289eb06..0000000000 --- a/tests/unit/storage_clients/_memory/test_key_value_store_collection_client.py +++ /dev/null @@ -1,42 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING - -import pytest - -if TYPE_CHECKING: - from crawlee.storage_clients import MemoryStorageClient - from crawlee.storage_clients._memory import KeyValueStoreCollectionClient - - -@pytest.fixture -def key_value_stores_client(memory_storage_client: MemoryStorageClient) -> KeyValueStoreCollectionClient: - return memory_storage_client.key_value_stores() - - -async def test_get_or_create(key_value_stores_client: KeyValueStoreCollectionClient) -> None: - kvs_name = 'test' - # A new kvs gets created - kvs_info = await key_value_stores_client.get_or_create(name=kvs_name) - assert kvs_info.name == kvs_name - - # Another get_or_create call returns the same kvs - kvs_info_existing = await key_value_stores_client.get_or_create(name=kvs_name) - assert kvs_info.id == kvs_info_existing.id - assert kvs_info.name == kvs_info_existing.name - assert kvs_info.created_at == kvs_info_existing.created_at - - -async def test_list(key_value_stores_client: KeyValueStoreCollectionClient) -> None: - assert (await key_value_stores_client.list()).count == 0 - kvs_info = await key_value_stores_client.get_or_create(name='kvs') - kvs_list = await key_value_stores_client.list() - assert kvs_list.count == 1 - assert kvs_list.items[0].name == kvs_info.name - - # Test sorting behavior - newer_kvs_info = await key_value_stores_client.get_or_create(name='newer-kvs') - kvs_list_sorting = await key_value_stores_client.list() - assert kvs_list_sorting.count == 2 - assert kvs_list_sorting.items[0].name == kvs_info.name - assert kvs_list_sorting.items[1].name == newer_kvs_info.name diff --git a/tests/unit/storage_clients/_memory/test_memory_dataset_client.py b/tests/unit/storage_clients/_memory/test_memory_dataset_client.py new file mode 100644 index 0000000000..06da10b5f8 --- /dev/null +++ b/tests/unit/storage_clients/_memory/test_memory_dataset_client.py @@ -0,0 +1,332 @@ +from __future__ import annotations + +import asyncio +from datetime import datetime +from typing import TYPE_CHECKING + +import pytest + +from crawlee.configuration import Configuration +from crawlee.storage_clients import MemoryStorageClient +from crawlee.storage_clients._memory import MemoryDatasetClient +from 
crawlee.storage_clients.models import DatasetItemsListPage + +if TYPE_CHECKING: + from collections.abc import AsyncGenerator + + + + +@pytest.fixture +async def dataset_client() -> AsyncGenerator[MemoryDatasetClient, None]: + """Fixture that provides a fresh memory dataset client for each test.""" + client = await MemoryStorageClient().open_dataset_client(name='test_dataset') + yield client + await client.drop() + + +async def test_open_creates_new_dataset() -> None: + """Test that open() creates a new dataset with proper metadata and adds it to the cache.""" + client = await MemoryStorageClient().open_dataset_client(name='new_dataset') + + # Verify correct client type and properties + assert isinstance(client, MemoryDatasetClient) + assert client.metadata.id is not None + assert client.metadata.name == 'new_dataset' + assert client.metadata.item_count == 0 + assert isinstance(client.metadata.created_at, datetime) + assert isinstance(client.metadata.accessed_at, datetime) + assert isinstance(client.metadata.modified_at, datetime) + + # Verify the client was cached + assert 'new_dataset' in MemoryDatasetClient._cache_by_name + + +async def test_open_existing_dataset(dataset_client: MemoryDatasetClient) -> None: + """Test that open() loads an existing dataset with matching properties.""" + configuration = Configuration(purge_on_start=False) + + # Open the same dataset again + reopened_client = await MemoryStorageClient().open_dataset_client( + name=dataset_client.metadata.name, + configuration=configuration, + ) + + # Verify client properties + assert dataset_client.metadata.id == reopened_client.metadata.id + assert dataset_client.metadata.name == reopened_client.metadata.name + assert dataset_client.metadata.item_count == reopened_client.metadata.item_count + + # Verify clients (python) ids + assert id(dataset_client) == id(reopened_client) + + +async def test_dataset_client_purge_on_start() -> None: + """Test that purge_on_start=True clears existing data in the dataset.""" + configuration = Configuration(purge_on_start=True) + + # Create dataset and add data + dataset_client1 = await MemoryStorageClient().open_dataset_client( + name='test_purge_dataset', + configuration=configuration, + ) + await dataset_client1.push_data({'item': 'initial data'}) + + # Verify data was added + items = await dataset_client1.get_data() + assert len(items.items) == 1 + + # Reopen + dataset_client2 = await MemoryStorageClient().open_dataset_client( + name='test_purge_dataset', + configuration=configuration, + ) + + # Verify data was purged + items = await dataset_client2.get_data() + assert len(items.items) == 0 + + +async def test_dataset_client_no_purge_on_start() -> None: + """Test that purge_on_start=False keeps existing data in the dataset.""" + configuration = Configuration(purge_on_start=False) + + # Create dataset and add data + dataset_client1 = await MemoryStorageClient().open_dataset_client( + name='test_no_purge_dataset', + configuration=configuration, + ) + await dataset_client1.push_data({'item': 'preserved data'}) + + # Reopen + dataset_client2 = await MemoryStorageClient().open_dataset_client( + name='test_no_purge_dataset', + configuration=configuration, + ) + + # Verify data was preserved + items = await dataset_client2.get_data() + assert len(items.items) == 1 + assert items.items[0]['item'] == 'preserved data' + + +async def test_open_with_id_and_name() -> None: + """Test that open() can be used with both id and name parameters.""" + client = await 
MemoryStorageClient().open_dataset_client( + id='some-id', + name='some-name', + ) + assert client.metadata.id == 'some-id' + assert client.metadata.name == 'some-name' + + +async def test_push_data_single_item(dataset_client: MemoryDatasetClient) -> None: + """Test pushing a single item to the dataset and verifying it was stored correctly.""" + item = {'key': 'value', 'number': 42} + await dataset_client.push_data(item) + + # Verify item count was updated + assert dataset_client.metadata.item_count == 1 + + # Verify item was stored + result = await dataset_client.get_data() + assert result.count == 1 + assert result.items[0] == item + + +async def test_push_data_multiple_items(dataset_client: MemoryDatasetClient) -> None: + """Test pushing multiple items to the dataset and verifying they were stored correctly.""" + items = [ + {'id': 1, 'name': 'Item 1'}, + {'id': 2, 'name': 'Item 2'}, + {'id': 3, 'name': 'Item 3'}, + ] + await dataset_client.push_data(items) + + # Verify item count was updated + assert dataset_client.metadata.item_count == 3 + + # Verify items were stored + result = await dataset_client.get_data() + assert result.count == 3 + assert result.items == items + + +async def test_get_data_empty_dataset(dataset_client: MemoryDatasetClient) -> None: + """Test that getting data from an empty dataset returns empty results with correct metadata.""" + result = await dataset_client.get_data() + + assert isinstance(result, DatasetItemsListPage) + assert result.count == 0 + assert result.total == 0 + assert result.items == [] + + +async def test_get_data_with_items(dataset_client: MemoryDatasetClient) -> None: + """Test that all items pushed to the dataset can be retrieved with correct metadata.""" + # Add some items + items = [ + {'id': 1, 'name': 'Item 1'}, + {'id': 2, 'name': 'Item 2'}, + {'id': 3, 'name': 'Item 3'}, + ] + await dataset_client.push_data(items) + + # Get all items + result = await dataset_client.get_data() + + assert result.count == 3 + assert result.total == 3 + assert len(result.items) == 3 + assert result.items[0]['id'] == 1 + assert result.items[1]['id'] == 2 + assert result.items[2]['id'] == 3 + + +async def test_get_data_with_pagination(dataset_client: MemoryDatasetClient) -> None: + """Test that offset and limit parameters work correctly for dataset pagination.""" + # Add some items + items = [{'id': i} for i in range(1, 11)] # 10 items + await dataset_client.push_data(items) + + # Test offset + result = await dataset_client.get_data(offset=3) + assert result.count == 7 + assert result.offset == 3 + assert result.items[0]['id'] == 4 + + # Test limit + result = await dataset_client.get_data(limit=5) + assert result.count == 5 + assert result.limit == 5 + assert result.items[-1]['id'] == 5 + + # Test both offset and limit + result = await dataset_client.get_data(offset=2, limit=3) + assert result.count == 3 + assert result.offset == 2 + assert result.limit == 3 + assert result.items[0]['id'] == 3 + assert result.items[-1]['id'] == 5 + + +async def test_get_data_descending_order(dataset_client: MemoryDatasetClient) -> None: + """Test that the desc parameter correctly reverses the order of returned items.""" + # Add some items + items = [{'id': i} for i in range(1, 6)] # 5 items + await dataset_client.push_data(items) + + # Get items in descending order + result = await dataset_client.get_data(desc=True) + + assert result.desc is True + assert result.items[0]['id'] == 5 + assert result.items[-1]['id'] == 1 + + +async def test_get_data_skip_empty(dataset_client: 
MemoryDatasetClient) -> None: + """Test that the skip_empty parameter correctly filters out empty items.""" + # Add some items including an empty one + items = [ + {'id': 1, 'name': 'Item 1'}, + {}, # Empty item + {'id': 3, 'name': 'Item 3'}, + ] + await dataset_client.push_data(items) + + # Get all items + result = await dataset_client.get_data() + assert result.count == 3 + + # Get non-empty items + result = await dataset_client.get_data(skip_empty=True) + assert result.count == 2 + assert all(item != {} for item in result.items) + + +async def test_iterate(dataset_client: MemoryDatasetClient) -> None: + """Test that iterate_items yields each item in the dataset in the correct order.""" + # Add some items + items = [{'id': i} for i in range(1, 6)] # 5 items + await dataset_client.push_data(items) + + # Iterate over all items + collected_items = [item async for item in dataset_client.iterate_items()] + + assert len(collected_items) == 5 + assert collected_items[0]['id'] == 1 + assert collected_items[-1]['id'] == 5 + + +async def test_iterate_with_options(dataset_client: MemoryDatasetClient) -> None: + """Test that iterate_items respects offset, limit, and desc parameters.""" + # Add some items + items = [{'id': i} for i in range(1, 11)] # 10 items + await dataset_client.push_data(items) + + # Test with offset and limit + collected_items = [item async for item in dataset_client.iterate_items(offset=3, limit=3)] + + assert len(collected_items) == 3 + assert collected_items[0]['id'] == 4 + assert collected_items[-1]['id'] == 6 + + # Test with descending order + collected_items = [] + async for item in dataset_client.iterate_items(desc=True, limit=3): + collected_items.append(item) + + assert len(collected_items) == 3 + assert collected_items[0]['id'] == 10 + assert collected_items[-1]['id'] == 8 + + +async def test_drop(dataset_client: MemoryDatasetClient) -> None: + """Test that drop removes the dataset from cache and resets its state.""" + await dataset_client.push_data({'test': 'data'}) + + # Verify the dataset exists in the cache + assert dataset_client.metadata.name in MemoryDatasetClient._cache_by_name + + # Drop the dataset + await dataset_client.drop() + + # Verify the dataset was removed from the cache + assert dataset_client.metadata.name not in MemoryDatasetClient._cache_by_name + + # Verify the dataset is empty + assert dataset_client.metadata.item_count == 0 + result = await dataset_client.get_data() + assert result.count == 0 + + +async def test_metadata_updates(dataset_client: MemoryDatasetClient) -> None: + """Test that read/write operations properly update accessed_at and modified_at timestamps.""" + # Record initial timestamps + initial_created = dataset_client.metadata.created_at + initial_accessed = dataset_client.metadata.accessed_at + initial_modified = dataset_client.metadata.modified_at + + # Wait a moment to ensure timestamps can change + await asyncio.sleep(0.01) + + # Perform an operation that updates accessed_at + await dataset_client.get_data() + + # Verify timestamps + assert dataset_client.metadata.created_at == initial_created + assert dataset_client.metadata.accessed_at > initial_accessed + assert dataset_client.metadata.modified_at == initial_modified + + accessed_after_get = dataset_client.metadata.accessed_at + + # Wait a moment to ensure timestamps can change + await asyncio.sleep(0.01) + + # Perform an operation that updates modified_at + await dataset_client.push_data({'new': 'item'}) + + # Verify timestamps again + assert 
dataset_client.metadata.created_at == initial_created + assert dataset_client.metadata.modified_at > initial_modified + assert dataset_client.metadata.accessed_at > accessed_after_get diff --git a/tests/unit/storage_clients/_memory/test_memory_kvs_client.py b/tests/unit/storage_clients/_memory/test_memory_kvs_client.py new file mode 100644 index 0000000000..b179b3fd3e --- /dev/null +++ b/tests/unit/storage_clients/_memory/test_memory_kvs_client.py @@ -0,0 +1,294 @@ +from __future__ import annotations + +import asyncio +from datetime import datetime +from typing import TYPE_CHECKING, Any + +import pytest + +from crawlee.configuration import Configuration +from crawlee.storage_clients import MemoryStorageClient +from crawlee.storage_clients._memory import MemoryKeyValueStoreClient +from crawlee.storage_clients.models import KeyValueStoreRecordMetadata + +if TYPE_CHECKING: + from collections.abc import AsyncGenerator + + + + +@pytest.fixture +async def kvs_client() -> AsyncGenerator[MemoryKeyValueStoreClient, None]: + """Fixture that provides a fresh memory key-value store client for each test.""" + client = await MemoryStorageClient().open_key_value_store_client(name='test_kvs') + yield client + await client.drop() + + +async def test_open_creates_new_kvs() -> None: + """Test that open() creates a new key-value store with proper metadata and adds it to the cache.""" + client = await MemoryStorageClient().open_key_value_store_client(name='new_kvs') + + # Verify correct client type and properties + assert isinstance(client, MemoryKeyValueStoreClient) + assert client.metadata.id is not None + assert client.metadata.name == 'new_kvs' + assert isinstance(client.metadata.created_at, datetime) + assert isinstance(client.metadata.accessed_at, datetime) + assert isinstance(client.metadata.modified_at, datetime) + + # Verify the client was cached + assert 'new_kvs' in MemoryKeyValueStoreClient._cache_by_name + + +async def test_open_existing_kvs(kvs_client: MemoryKeyValueStoreClient) -> None: + """Test that open() loads an existing key-value store with matching properties.""" + configuration = Configuration(purge_on_start=False) + # Open the same key-value store again + reopened_client = await MemoryStorageClient().open_key_value_store_client( + name=kvs_client.metadata.name, + configuration=configuration, + ) + + # Verify client properties + assert kvs_client.metadata.id == reopened_client.metadata.id + assert kvs_client.metadata.name == reopened_client.metadata.name + + # Verify clients (python) ids + assert id(kvs_client) == id(reopened_client) + + +async def test_kvs_client_purge_on_start() -> None: + """Test that purge_on_start=True clears existing data in the KVS.""" + configuration = Configuration(purge_on_start=True) + + # Create KVS and add data + kvs_client1 = await MemoryStorageClient().open_key_value_store_client( + name='test_purge_kvs', + configuration=configuration, + ) + await kvs_client1.set_value(key='test-key', value='initial value') + + # Verify value was set + record = await kvs_client1.get_value(key='test-key') + assert record is not None + assert record.value == 'initial value' + + # Reopen + kvs_client2 = await MemoryStorageClient().open_key_value_store_client( + name='test_purge_kvs', + configuration=configuration, + ) + + # Verify value was purged + record = await kvs_client2.get_value(key='test-key') + assert record is None + + +async def test_kvs_client_no_purge_on_start() -> None: + """Test that purge_on_start=False keeps existing data in the KVS.""" + configuration = 
Configuration(purge_on_start=False) + + # Create KVS and add data + kvs_client1 = await MemoryStorageClient().open_key_value_store_client( + name='test_no_purge_kvs', + configuration=configuration, + ) + await kvs_client1.set_value(key='test-key', value='preserved value') + + # Reopen + kvs_client2 = await MemoryStorageClient().open_key_value_store_client( + name='test_no_purge_kvs', + configuration=configuration, + ) + + # Verify value was preserved + record = await kvs_client2.get_value(key='test-key') + assert record is not None + assert record.value == 'preserved value' + + +async def test_open_with_id_and_name() -> None: + """Test that open() can be used with both id and name parameters.""" + client = await MemoryStorageClient().open_key_value_store_client( + id='some-id', + name='some-name', + ) + assert client.metadata.id == 'some-id' + assert client.metadata.name == 'some-name' + + +@pytest.mark.parametrize( + ('key', 'value', 'expected_content_type'), + [ + pytest.param('string_key', 'string value', 'text/plain; charset=utf-8', id='string'), + pytest.param('dict_key', {'name': 'test', 'value': 42}, 'application/json; charset=utf-8', id='dictionary'), + pytest.param('list_key', [1, 2, 3], 'application/json; charset=utf-8', id='list'), + pytest.param('bytes_key', b'binary data', 'application/octet-stream', id='bytes'), + ], +) +async def test_set_get_value( + kvs_client: MemoryKeyValueStoreClient, + key: str, + value: Any, + expected_content_type: str, +) -> None: + """Test storing and retrieving different types of values with correct content types.""" + # Set value + await kvs_client.set_value(key=key, value=value) + + # Get and verify value + record = await kvs_client.get_value(key=key) + assert record is not None + assert record.key == key + assert record.value == value + assert record.content_type == expected_content_type + + +async def test_get_nonexistent_value(kvs_client: MemoryKeyValueStoreClient) -> None: + """Test that attempting to get a non-existent key returns None.""" + record = await kvs_client.get_value(key='nonexistent') + assert record is None + + +async def test_set_value_with_explicit_content_type(kvs_client: MemoryKeyValueStoreClient) -> None: + """Test that an explicitly provided content type overrides the automatically inferred one.""" + value = 'This could be XML' + content_type = 'application/xml' + + await kvs_client.set_value(key='xml_key', value=value, content_type=content_type) + + record = await kvs_client.get_value(key='xml_key') + assert record is not None + assert record.value == value + assert record.content_type == content_type + + +async def test_delete_value(kvs_client: MemoryKeyValueStoreClient) -> None: + """Test that a stored value can be deleted and is no longer retrievable after deletion.""" + # Set a value + await kvs_client.set_value(key='delete_me', value='to be deleted') + + # Verify it exists + record = await kvs_client.get_value(key='delete_me') + assert record is not None + + # Delete it + await kvs_client.delete_value(key='delete_me') + + # Verify it's gone + record = await kvs_client.get_value(key='delete_me') + assert record is None + + +async def test_delete_nonexistent_value(kvs_client: MemoryKeyValueStoreClient) -> None: + """Test that attempting to delete a non-existent key is a no-op and doesn't raise errors.""" + # Should not raise an error + await kvs_client.delete_value(key='nonexistent') + + +async def test_iterate_keys(kvs_client: MemoryKeyValueStoreClient) -> None: + """Test that all keys can be iterated over and are 
returned in sorted order with correct metadata.""" + # Set some values + items = { + 'a_key': 'value A', + 'b_key': 'value B', + 'c_key': 'value C', + 'd_key': 'value D', + } + + for key, value in items.items(): + await kvs_client.set_value(key=key, value=value) + + # Get all keys + metadata_list = [metadata async for metadata in kvs_client.iterate_keys()] + + # Verify keys are returned in sorted order + assert len(metadata_list) == 4 + assert [m.key for m in metadata_list] == sorted(items.keys()) + assert all(isinstance(m, KeyValueStoreRecordMetadata) for m in metadata_list) + + +async def test_iterate_keys_with_exclusive_start_key(kvs_client: MemoryKeyValueStoreClient) -> None: + """Test that exclusive_start_key parameter returns only keys after it alphabetically.""" + # Set some values + for key in ['a_key', 'b_key', 'c_key', 'd_key', 'e_key']: + await kvs_client.set_value(key=key, value=f'value for {key}') + + # Get keys starting after 'b_key' + metadata_list = [metadata async for metadata in kvs_client.iterate_keys(exclusive_start_key='b_key')] + + # Verify only keys after 'b_key' are returned + assert len(metadata_list) == 3 + assert [m.key for m in metadata_list] == ['c_key', 'd_key', 'e_key'] + + +async def test_iterate_keys_with_limit(kvs_client: MemoryKeyValueStoreClient) -> None: + """Test that the limit parameter returns only the specified number of keys.""" + # Set some values + for key in ['a_key', 'b_key', 'c_key', 'd_key', 'e_key']: + await kvs_client.set_value(key=key, value=f'value for {key}') + + # Get first 3 keys + metadata_list = [metadata async for metadata in kvs_client.iterate_keys(limit=3)] + + # Verify only the first 3 keys are returned + assert len(metadata_list) == 3 + assert [m.key for m in metadata_list] == ['a_key', 'b_key', 'c_key'] + + +async def test_drop(kvs_client: MemoryKeyValueStoreClient) -> None: + """Test that drop removes the store from cache and clears all data.""" + # Add some values to the store + await kvs_client.set_value(key='test', value='data') + + # Verify the store exists in the cache + assert kvs_client.metadata.name in MemoryKeyValueStoreClient._cache_by_name + + # Drop the store + await kvs_client.drop() + + # Verify the store was removed from the cache + assert kvs_client.metadata.name not in MemoryKeyValueStoreClient._cache_by_name + + # Verify the store is empty + record = await kvs_client.get_value(key='test') + assert record is None + + +async def test_get_public_url(kvs_client: MemoryKeyValueStoreClient) -> None: + """Test that get_public_url raises NotImplementedError for the memory implementation.""" + with pytest.raises(NotImplementedError): + await kvs_client.get_public_url(key='any-key') + + +async def test_metadata_updates(kvs_client: MemoryKeyValueStoreClient) -> None: + """Test that read/write operations properly update accessed_at and modified_at timestamps.""" + # Record initial timestamps + initial_created = kvs_client.metadata.created_at + initial_accessed = kvs_client.metadata.accessed_at + initial_modified = kvs_client.metadata.modified_at + + # Wait a moment to ensure timestamps can change + await asyncio.sleep(0.01) + + # Perform an operation that updates accessed_at + await kvs_client.get_value(key='nonexistent') + + # Verify timestamps + assert kvs_client.metadata.created_at == initial_created + assert kvs_client.metadata.accessed_at > initial_accessed + assert kvs_client.metadata.modified_at == initial_modified + + accessed_after_get = kvs_client.metadata.accessed_at + + # Wait a moment to ensure timestamps 
can change + await asyncio.sleep(0.01) + + # Perform an operation that updates modified_at and accessed_at + await kvs_client.set_value(key='new_key', value='new value') + + # Verify timestamps again + assert kvs_client.metadata.created_at == initial_created + assert kvs_client.metadata.modified_at > initial_modified + assert kvs_client.metadata.accessed_at > accessed_after_get diff --git a/tests/unit/storage_clients/_memory/test_memory_rq_client.py b/tests/unit/storage_clients/_memory/test_memory_rq_client.py new file mode 100644 index 0000000000..f5b6c16adb --- /dev/null +++ b/tests/unit/storage_clients/_memory/test_memory_rq_client.py @@ -0,0 +1,495 @@ +from __future__ import annotations + +import asyncio +from datetime import datetime +from typing import TYPE_CHECKING + +import pytest + +from crawlee import Request +from crawlee.configuration import Configuration +from crawlee.storage_clients import MemoryStorageClient +from crawlee.storage_clients._memory import MemoryRequestQueueClient + +if TYPE_CHECKING: + from collections.abc import AsyncGenerator + + + + +@pytest.fixture +async def rq_client() -> AsyncGenerator[MemoryRequestQueueClient, None]: + """Fixture that provides a fresh memory request queue client for each test.""" + client = await MemoryStorageClient().open_request_queue_client(name='test_rq') + yield client + await client.drop() + + +async def test_open_creates_new_rq() -> None: + """Test that open() creates a new request queue with proper metadata and adds it to the cache.""" + client = await MemoryStorageClient().open_request_queue_client(name='new_rq') + + # Verify correct client type and properties + assert isinstance(client, MemoryRequestQueueClient) + assert client.metadata.id is not None + assert client.metadata.name == 'new_rq' + assert isinstance(client.metadata.created_at, datetime) + assert isinstance(client.metadata.accessed_at, datetime) + assert isinstance(client.metadata.modified_at, datetime) + assert client.metadata.handled_request_count == 0 + assert client.metadata.pending_request_count == 0 + assert client.metadata.total_request_count == 0 + assert client.metadata.had_multiple_clients is False + + # Verify the client was cached + assert 'new_rq' in MemoryRequestQueueClient._cache_by_name + + +async def test_open_existing_rq(rq_client: MemoryRequestQueueClient) -> None: + """Test that open() loads an existing request queue with matching properties.""" + configuration = Configuration(purge_on_start=False) + # Open the same request queue again + reopened_client = await MemoryStorageClient().open_request_queue_client( + name=rq_client.metadata.name, + configuration=configuration, + ) + + # Verify client properties + assert rq_client.metadata.id == reopened_client.metadata.id + assert rq_client.metadata.name == reopened_client.metadata.name + + # Verify clients (python) ids + assert id(rq_client) == id(reopened_client) + + +async def test_rq_client_purge_on_start() -> None: + """Test that purge_on_start=True clears existing data in the RQ.""" + configuration = Configuration(purge_on_start=True) + + # Create RQ and add data + rq_client1 = await MemoryStorageClient().open_request_queue_client( + name='test_purge_rq', + configuration=configuration, + ) + request = Request.from_url(url='https://example.com/initial') + await rq_client1.add_batch_of_requests([request]) + + # Verify request was added + assert await rq_client1.is_empty() is False + + # Reopen + rq_client2 = await MemoryStorageClient().open_request_queue_client( + name='test_purge_rq', + 
configuration=configuration, + ) + + # Verify queue was purged + assert await rq_client2.is_empty() is True + + +async def test_rq_client_no_purge_on_start() -> None: + """Test that purge_on_start=False keeps existing data in the RQ.""" + configuration = Configuration(purge_on_start=False) + + # Create RQ and add data + rq_client1 = await MemoryStorageClient().open_request_queue_client( + name='test_no_purge_rq', + configuration=configuration, + ) + request = Request.from_url(url='https://example.com/preserved') + await rq_client1.add_batch_of_requests([request]) + + # Reopen + rq_client2 = await MemoryStorageClient().open_request_queue_client( + name='test_no_purge_rq', + configuration=configuration, + ) + + # Verify request was preserved + assert await rq_client2.is_empty() is False + next_request = await rq_client2.fetch_next_request() + assert next_request is not None + assert next_request.url == 'https://example.com/preserved' + + +async def test_open_with_id_and_name() -> None: + """Test that open() can be used with both id and name parameters.""" + client = await MemoryStorageClient().open_request_queue_client( + id='some-id', + name='some-name', + ) + assert client.metadata.id is not None # ID is always auto-generated + assert client.metadata.name == 'some-name' + + +async def test_add_batch_of_requests(rq_client: MemoryRequestQueueClient) -> None: + """Test adding a batch of requests to the queue.""" + requests = [ + Request.from_url(url='https://example.com/1'), + Request.from_url(url='https://example.com/2'), + Request.from_url(url='https://example.com/3'), + ] + + response = await rq_client.add_batch_of_requests(requests) + + # Verify correct response + assert len(response.processed_requests) == 3 + assert len(response.unprocessed_requests) == 0 + + # Verify each request was processed correctly + for i, req in enumerate(requests): + assert response.processed_requests[i].id == req.id + assert response.processed_requests[i].unique_key == req.unique_key + assert response.processed_requests[i].was_already_present is False + assert response.processed_requests[i].was_already_handled is False + + # Verify metadata was updated + assert rq_client.metadata.total_request_count == 3 + assert rq_client.metadata.pending_request_count == 3 + + +async def test_add_batch_of_requests_with_duplicates(rq_client: MemoryRequestQueueClient) -> None: + """Test adding requests with duplicate unique keys.""" + # Add initial requests + initial_requests = [ + Request.from_url(url='https://example.com/1', unique_key='key1'), + Request.from_url(url='https://example.com/2', unique_key='key2'), + ] + await rq_client.add_batch_of_requests(initial_requests) + + # Mark first request as handled + req1 = await rq_client.fetch_next_request() + assert req1 is not None + await rq_client.mark_request_as_handled(req1) + + # Add duplicate requests + duplicate_requests = [ + Request.from_url(url='https://example.com/1-dup', unique_key='key1'), # Same as first (handled) + Request.from_url(url='https://example.com/2-dup', unique_key='key2'), # Same as second (not handled) + Request.from_url(url='https://example.com/3', unique_key='key3'), # New request + ] + response = await rq_client.add_batch_of_requests(duplicate_requests) + + # Verify response + assert len(response.processed_requests) == 3 + + # First request should be marked as already handled + assert response.processed_requests[0].was_already_present is True + assert response.processed_requests[0].was_already_handled is True + + # Second request should be marked as 
already present but not handled + assert response.processed_requests[1].was_already_present is True + assert response.processed_requests[1].was_already_handled is False + + # Third request should be new + assert response.processed_requests[2].was_already_present is False + assert response.processed_requests[2].was_already_handled is False + + +async def test_add_batch_of_requests_to_forefront(rq_client: MemoryRequestQueueClient) -> None: + """Test adding requests to the forefront of the queue.""" + # Add initial requests + initial_requests = [ + Request.from_url(url='https://example.com/1'), + Request.from_url(url='https://example.com/2'), + ] + await rq_client.add_batch_of_requests(initial_requests) + + # Add new requests to forefront + forefront_requests = [ + Request.from_url(url='https://example.com/priority'), + ] + await rq_client.add_batch_of_requests(forefront_requests, forefront=True) + + # The priority request should be fetched first + next_request = await rq_client.fetch_next_request() + assert next_request is not None + assert next_request.url == 'https://example.com/priority' + + +async def test_fetch_next_request(rq_client: MemoryRequestQueueClient) -> None: + """Test fetching the next request from the queue.""" + # Add some requests + requests = [ + Request.from_url(url='https://example.com/1'), + Request.from_url(url='https://example.com/2'), + ] + await rq_client.add_batch_of_requests(requests) + + # Fetch first request + request1 = await rq_client.fetch_next_request() + assert request1 is not None + assert request1.url == 'https://example.com/1' + + # Fetch second request + request2 = await rq_client.fetch_next_request() + assert request2 is not None + assert request2.url == 'https://example.com/2' + + # No more requests + request3 = await rq_client.fetch_next_request() + assert request3 is None + + +async def test_fetch_skips_handled_requests(rq_client: MemoryRequestQueueClient) -> None: + """Test that fetch_next_request skips handled requests.""" + # Add requests + requests = [ + Request.from_url(url='https://example.com/1'), + Request.from_url(url='https://example.com/2'), + ] + await rq_client.add_batch_of_requests(requests) + + # Fetch and handle first request + request1 = await rq_client.fetch_next_request() + assert request1 is not None + await rq_client.mark_request_as_handled(request1) + + # Next fetch should return second request, not the handled one + request = await rq_client.fetch_next_request() + assert request is not None + assert request.url == 'https://example.com/2' + + +async def test_fetch_skips_in_progress_requests(rq_client: MemoryRequestQueueClient) -> None: + """Test that fetch_next_request skips requests that are already in progress.""" + # Add requests + requests = [ + Request.from_url(url='https://example.com/1'), + Request.from_url(url='https://example.com/2'), + ] + await rq_client.add_batch_of_requests(requests) + + # Fetch first request (it should be in progress now) + request1 = await rq_client.fetch_next_request() + assert request1 is not None + + # Next fetch should return second request, not the in-progress one + request2 = await rq_client.fetch_next_request() + assert request2 is not None + assert request2.url == 'https://example.com/2' + + # Third fetch should return None as all requests are in progress + request3 = await rq_client.fetch_next_request() + assert request3 is None + + +async def test_get_request(rq_client: MemoryRequestQueueClient) -> None: + """Test getting a request by ID.""" + # Add a request + request = 
Request.from_url(url='https://example.com/test') + await rq_client.add_batch_of_requests([request]) + + # Get the request by ID + retrieved_request = await rq_client.get_request(request.id) + assert retrieved_request is not None + assert retrieved_request.id == request.id + assert retrieved_request.url == request.url + + # Try to get a non-existent request + nonexistent = await rq_client.get_request('nonexistent-id') + assert nonexistent is None + + +async def test_get_in_progress_request(rq_client: MemoryRequestQueueClient) -> None: + """Test getting an in-progress request by ID.""" + # Add a request + request = Request.from_url(url='https://example.com/test') + await rq_client.add_batch_of_requests([request]) + + # Fetch the request to make it in-progress + fetched = await rq_client.fetch_next_request() + assert fetched is not None + + # Get the request by ID + retrieved = await rq_client.get_request(request.id) + assert retrieved is not None + assert retrieved.id == request.id + assert retrieved.url == request.url + + +async def test_mark_request_as_handled(rq_client: MemoryRequestQueueClient) -> None: + """Test marking a request as handled.""" + # Add a request + request = Request.from_url(url='https://example.com/test') + await rq_client.add_batch_of_requests([request]) + + # Fetch the request to make it in-progress + fetched = await rq_client.fetch_next_request() + assert fetched is not None + + # Mark as handled + result = await rq_client.mark_request_as_handled(fetched) + assert result is not None + assert result.id == fetched.id + assert result.was_already_handled is True + + # Check that metadata was updated + assert rq_client.metadata.handled_request_count == 1 + assert rq_client.metadata.pending_request_count == 0 + + # Try to mark again (should fail as it's no longer in-progress) + result = await rq_client.mark_request_as_handled(fetched) + assert result is None + + +async def test_reclaim_request(rq_client: MemoryRequestQueueClient) -> None: + """Test reclaiming a request back to the queue.""" + # Add a request + request = Request.from_url(url='https://example.com/test') + await rq_client.add_batch_of_requests([request]) + + # Fetch the request to make it in-progress + fetched = await rq_client.fetch_next_request() + assert fetched is not None + + # Reclaim the request + result = await rq_client.reclaim_request(fetched) + assert result is not None + assert result.id == fetched.id + assert result.was_already_handled is False + + # It should be available to fetch again + reclaimed = await rq_client.fetch_next_request() + assert reclaimed is not None + assert reclaimed.id == fetched.id + + +async def test_reclaim_request_to_forefront(rq_client: MemoryRequestQueueClient) -> None: + """Test reclaiming a request to the forefront of the queue.""" + # Add requests + requests = [ + Request.from_url(url='https://example.com/1'), + Request.from_url(url='https://example.com/2'), + ] + await rq_client.add_batch_of_requests(requests) + + # Fetch the second request to make it in-progress + await rq_client.fetch_next_request() # Skip the first one + request2 = await rq_client.fetch_next_request() + assert request2 is not None + assert request2.url == 'https://example.com/2' + + # Reclaim the request to forefront + await rq_client.reclaim_request(request2, forefront=True) + + # It should now be the first in the queue + next_request = await rq_client.fetch_next_request() + assert next_request is not None + assert next_request.url == 'https://example.com/2' + + +async def 
test_is_empty(rq_client: MemoryRequestQueueClient) -> None: + """Test checking if the queue is empty.""" + # Initially empty + assert await rq_client.is_empty() is True + + # Add a request + request = Request.from_url(url='https://example.com/test') + await rq_client.add_batch_of_requests([request]) + + # Not empty now + assert await rq_client.is_empty() is False + + # Fetch and handle + fetched = await rq_client.fetch_next_request() + assert fetched is not None + await rq_client.mark_request_as_handled(fetched) + + # Empty again (all requests handled) + assert await rq_client.is_empty() is True + + +async def test_is_empty_with_in_progress(rq_client: MemoryRequestQueueClient) -> None: + """Test that in-progress requests don't affect is_empty.""" + # Add a request + request = Request.from_url(url='https://example.com/test') + await rq_client.add_batch_of_requests([request]) + + # Fetch but don't handle + await rq_client.fetch_next_request() + + # Queue should still be considered non-empty + # This is because the request hasn't been handled yet + assert await rq_client.is_empty() is False + + +async def test_drop(rq_client: MemoryRequestQueueClient) -> None: + """Test that drop removes the queue from cache and clears all data.""" + # Add a request + request = Request.from_url(url='https://example.com/test') + await rq_client.add_batch_of_requests([request]) + + # Verify the queue exists in the cache + assert rq_client.metadata.name in MemoryRequestQueueClient._cache_by_name + + # Drop the queue + await rq_client.drop() + + # Verify the queue was removed from the cache + assert rq_client.metadata.name not in MemoryRequestQueueClient._cache_by_name + + # Verify the queue is empty + assert await rq_client.is_empty() is True + + +async def test_metadata_updates(rq_client: MemoryRequestQueueClient) -> None: + """Test that operations properly update metadata timestamps.""" + # Record initial timestamps + initial_created = rq_client.metadata.created_at + initial_accessed = rq_client.metadata.accessed_at + initial_modified = rq_client.metadata.modified_at + + # Wait a moment to ensure timestamps can change + await asyncio.sleep(0.01) + + # Perform an operation that updates modified_at and accessed_at + request = Request.from_url(url='https://example.com/test') + await rq_client.add_batch_of_requests([request]) + + # Verify timestamps + assert rq_client.metadata.created_at == initial_created + assert rq_client.metadata.modified_at > initial_modified + assert rq_client.metadata.accessed_at > initial_accessed + + # Wait a moment to ensure timestamps can change + await asyncio.sleep(0.01) + + # Record timestamps after add + accessed_after_add = rq_client.metadata.accessed_at + modified_after_add = rq_client.metadata.modified_at + + # Check is_empty (should only update accessed_at) + await rq_client.is_empty() + + # Wait a moment to ensure timestamps can change + await asyncio.sleep(0.01) + + # Verify only accessed_at changed + assert rq_client.metadata.modified_at == modified_after_add + assert rq_client.metadata.accessed_at > accessed_after_add + + +async def test_unique_key_generation(rq_client: MemoryRequestQueueClient) -> None: + """Test that unique keys are auto-generated if not provided.""" + # Add requests without explicit unique keys + requests = [ + Request.from_url(url='https://example.com/1'), + Request.from_url(url='https://example.com/1', always_enqueue=True) + ] + response = await rq_client.add_batch_of_requests(requests) + + # Both should be added as their auto-generated unique keys will 
differ + assert len(response.processed_requests) == 2 + assert all(not pr.was_already_present for pr in response.processed_requests) + + # Add a request with explicit unique key + request = Request.from_url(url='https://example.com/2', unique_key='explicit-key') + await rq_client.add_batch_of_requests([request]) + + # Add duplicate with same unique key + duplicate = Request.from_url(url='https://example.com/different', unique_key='explicit-key') + duplicate_response = await rq_client.add_batch_of_requests([duplicate]) + + # Should be marked as already present + assert duplicate_response.processed_requests[0].was_already_present is True diff --git a/tests/unit/storage_clients/_memory/test_memory_storage_client.py b/tests/unit/storage_clients/_memory/test_memory_storage_client.py deleted file mode 100644 index 0d043322ae..0000000000 --- a/tests/unit/storage_clients/_memory/test_memory_storage_client.py +++ /dev/null @@ -1,288 +0,0 @@ -# TODO: Update crawlee_storage_dir args once the Pydantic bug is fixed -# https://github.com/apify/crawlee-python/issues/146 - -from __future__ import annotations - -from pathlib import Path - -import pytest - -from crawlee import Request, service_locator -from crawlee._consts import METADATA_FILENAME -from crawlee.configuration import Configuration -from crawlee.storage_clients import MemoryStorageClient -from crawlee.storage_clients.models import BatchRequestsOperationResponse - - -async def test_write_metadata(tmp_path: Path) -> None: - dataset_name = 'test' - dataset_no_metadata_name = 'test-no-metadata' - ms = MemoryStorageClient.from_config( - Configuration( - crawlee_storage_dir=str(tmp_path), # type: ignore[call-arg] - write_metadata=True, - ), - ) - ms_no_metadata = MemoryStorageClient.from_config( - Configuration( - crawlee_storage_dir=str(tmp_path), # type: ignore[call-arg] - write_metadata=False, - ) - ) - datasets_client = ms.datasets() - datasets_no_metadata_client = ms_no_metadata.datasets() - await datasets_client.get_or_create(name=dataset_name) - await datasets_no_metadata_client.get_or_create(name=dataset_no_metadata_name) - assert Path(ms.datasets_directory, dataset_name, METADATA_FILENAME).exists() is True - assert Path(ms_no_metadata.datasets_directory, dataset_no_metadata_name, METADATA_FILENAME).exists() is False - - -@pytest.mark.parametrize( - 'persist_storage', - [ - True, - False, - ], -) -async def test_persist_storage(persist_storage: bool, tmp_path: Path) -> None: # noqa: FBT001 - ms = MemoryStorageClient.from_config( - Configuration( - crawlee_storage_dir=str(tmp_path), # type: ignore[call-arg] - persist_storage=persist_storage, - ) - ) - - # Key value stores - kvs_client = ms.key_value_stores() - kvs_info = await kvs_client.get_or_create(name='kvs') - await ms.key_value_store(kvs_info.id).set_record('test', {'x': 1}, 'application/json') - - path = Path(ms.key_value_stores_directory) / (kvs_info.name or '') / 'test.json' - assert path.exists() is persist_storage - - # Request queues - rq_client = ms.request_queues() - rq_info = await rq_client.get_or_create(name='rq') - - request = Request.from_url('http://lorem.com') - await ms.request_queue(rq_info.id).add_request(request) - - path = Path(ms.request_queues_directory) / (rq_info.name or '') / f'{request.id}.json' - assert path.exists() is persist_storage - - # Datasets - ds_client = ms.datasets() - ds_info = await ds_client.get_or_create(name='ds') - - await ms.dataset(ds_info.id).push_items([{'foo': 'bar'}]) - - -def 
test_persist_storage_set_to_false_via_string_env_var(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None: - monkeypatch.setenv('CRAWLEE_PERSIST_STORAGE', 'false') - ms = MemoryStorageClient.from_config( - Configuration(crawlee_storage_dir=str(tmp_path)), # type: ignore[call-arg] - ) - assert ms.persist_storage is False - - -def test_persist_storage_set_to_false_via_numeric_env_var(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None: - monkeypatch.setenv('CRAWLEE_PERSIST_STORAGE', '0') - ms = MemoryStorageClient.from_config(Configuration(crawlee_storage_dir=str(tmp_path))) # type: ignore[call-arg] - assert ms.persist_storage is False - - -def test_persist_storage_true_via_constructor_arg(tmp_path: Path) -> None: - ms = MemoryStorageClient.from_config( - Configuration( - crawlee_storage_dir=str(tmp_path), # type: ignore[call-arg] - persist_storage=True, - ) - ) - assert ms.persist_storage is True - - -def test_default_write_metadata_behavior(tmp_path: Path) -> None: - # Default behavior - ms = MemoryStorageClient.from_config( - Configuration(crawlee_storage_dir=str(tmp_path)), # type: ignore[call-arg] - ) - assert ms.write_metadata is True - - -def test_write_metadata_set_to_false_via_env_var(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None: - # Test if env var changes write_metadata to False - monkeypatch.setenv('CRAWLEE_WRITE_METADATA', 'false') - ms = MemoryStorageClient.from_config( - Configuration(crawlee_storage_dir=str(tmp_path)), # type: ignore[call-arg] - ) - assert ms.write_metadata is False - - -def test_write_metadata_false_via_constructor_arg_overrides_env_var(tmp_path: Path) -> None: - # Test if constructor arg takes precedence over env var value - ms = MemoryStorageClient.from_config( - Configuration( - write_metadata=False, - crawlee_storage_dir=str(tmp_path), # type: ignore[call-arg] - ) - ) - assert ms.write_metadata is False - - -async def test_purge_datasets(tmp_path: Path) -> None: - ms = MemoryStorageClient.from_config( - Configuration( - write_metadata=True, - crawlee_storage_dir=str(tmp_path), # type: ignore[call-arg] - ) - ) - # Create default and non-default datasets - datasets_client = ms.datasets() - default_dataset_info = await datasets_client.get_or_create(name='default') - non_default_dataset_info = await datasets_client.get_or_create(name='non-default') - - # Check all folders inside datasets directory before and after purge - assert default_dataset_info.name is not None - assert non_default_dataset_info.name is not None - - default_path = Path(ms.datasets_directory, default_dataset_info.name) - non_default_path = Path(ms.datasets_directory, non_default_dataset_info.name) - - assert default_path.exists() is True - assert non_default_path.exists() is True - - await ms._purge_default_storages() - - assert default_path.exists() is False - assert non_default_path.exists() is True - - -async def test_purge_key_value_stores(tmp_path: Path) -> None: - ms = MemoryStorageClient.from_config( - Configuration( - write_metadata=True, - crawlee_storage_dir=str(tmp_path), # type: ignore[call-arg] - ) - ) - - # Create default and non-default key-value stores - kvs_client = ms.key_value_stores() - default_kvs_info = await kvs_client.get_or_create(name='default') - non_default_kvs_info = await kvs_client.get_or_create(name='non-default') - default_kvs_client = ms.key_value_store(default_kvs_info.id) - # INPUT.json should be kept - await default_kvs_client.set_record('INPUT', {'abc': 123}, 'application/json') - # test.json should not be kept - await 
default_kvs_client.set_record('test', {'abc': 123}, 'application/json') - - # Check all folders and files inside kvs directory before and after purge - assert default_kvs_info.name is not None - assert non_default_kvs_info.name is not None - - default_kvs_path = Path(ms.key_value_stores_directory, default_kvs_info.name) - non_default_kvs_path = Path(ms.key_value_stores_directory, non_default_kvs_info.name) - kvs_directory = Path(ms.key_value_stores_directory, 'default') - - assert default_kvs_path.exists() is True - assert non_default_kvs_path.exists() is True - - assert (kvs_directory / 'INPUT.json').exists() is True - assert (kvs_directory / 'test.json').exists() is True - - await ms._purge_default_storages() - - assert default_kvs_path.exists() is True - assert non_default_kvs_path.exists() is True - - assert (kvs_directory / 'INPUT.json').exists() is True - assert (kvs_directory / 'test.json').exists() is False - - -async def test_purge_request_queues(tmp_path: Path) -> None: - ms = MemoryStorageClient.from_config( - Configuration( - write_metadata=True, - crawlee_storage_dir=str(tmp_path), # type: ignore[call-arg] - ) - ) - # Create default and non-default request queues - rq_client = ms.request_queues() - default_rq_info = await rq_client.get_or_create(name='default') - non_default_rq_info = await rq_client.get_or_create(name='non-default') - - # Check all folders inside rq directory before and after purge - assert default_rq_info.name - assert non_default_rq_info.name - - default_rq_path = Path(ms.request_queues_directory, default_rq_info.name) - non_default_rq_path = Path(ms.request_queues_directory, non_default_rq_info.name) - - assert default_rq_path.exists() is True - assert non_default_rq_path.exists() is True - - await ms._purge_default_storages() - - assert default_rq_path.exists() is False - assert non_default_rq_path.exists() is True - - -async def test_not_implemented_method(tmp_path: Path) -> None: - ms = MemoryStorageClient.from_config( - Configuration( - write_metadata=True, - crawlee_storage_dir=str(tmp_path), # type: ignore[call-arg] - ) - ) - ddt = ms.dataset('test') - with pytest.raises(NotImplementedError, match='This method is not supported in memory storage.'): - await ddt.stream_items(item_format='json') - - with pytest.raises(NotImplementedError, match='This method is not supported in memory storage.'): - await ddt.stream_items(item_format='json') - - -async def test_default_storage_path_used(monkeypatch: pytest.MonkeyPatch) -> None: - # Reset the configuration in service locator - service_locator._configuration = None - service_locator._configuration_was_retrieved = False - - # Remove the env var for setting the storage directory - monkeypatch.delenv('CRAWLEE_STORAGE_DIR', raising=False) - - # Initialize the service locator with default configuration - msc = MemoryStorageClient.from_config() - assert msc.storage_dir == './storage' - - -async def test_storage_path_from_env_var_overrides_default(monkeypatch: pytest.MonkeyPatch) -> None: - # We expect the env var to override the default value - monkeypatch.setenv('CRAWLEE_STORAGE_DIR', './env_var_storage_dir') - service_locator.set_configuration(Configuration()) - ms = MemoryStorageClient.from_config() - assert ms.storage_dir == './env_var_storage_dir' - - -async def test_parametrized_storage_path_overrides_env_var() -> None: - # We expect the parametrized value to be used - ms = MemoryStorageClient.from_config( - Configuration(crawlee_storage_dir='./parametrized_storage_dir'), # type: ignore[call-arg] - ) - 
assert ms.storage_dir == './parametrized_storage_dir' - - -async def test_batch_requests_operation_response() -> None: - """Test that `BatchRequestsOperationResponse` creation from example responses.""" - process_request = { - 'requestId': 'EAaArVRs5qV39C9', - 'uniqueKey': 'https://example.com', - 'wasAlreadyHandled': False, - 'wasAlreadyPresent': True, - } - unprocess_request_full = {'uniqueKey': 'https://example2.com', 'method': 'GET', 'url': 'https://example2.com'} - unprocess_request_minimal = {'uniqueKey': 'https://example3.com', 'url': 'https://example3.com'} - BatchRequestsOperationResponse.model_validate( - { - 'processedRequests': [process_request], - 'unprocessedRequests': [unprocess_request_full, unprocess_request_minimal], - } - ) diff --git a/tests/unit/storage_clients/_memory/test_memory_storage_e2e.py b/tests/unit/storage_clients/_memory/test_memory_storage_e2e.py deleted file mode 100644 index c79fa66792..0000000000 --- a/tests/unit/storage_clients/_memory/test_memory_storage_e2e.py +++ /dev/null @@ -1,130 +0,0 @@ -from __future__ import annotations - -from datetime import datetime, timezone -from typing import Callable - -import pytest - -from crawlee import Request, service_locator -from crawlee.storages._key_value_store import KeyValueStore -from crawlee.storages._request_queue import RequestQueue - - -@pytest.mark.parametrize('purge_on_start', [True, False]) -async def test_actor_memory_storage_client_key_value_store_e2e( - monkeypatch: pytest.MonkeyPatch, - purge_on_start: bool, # noqa: FBT001 - prepare_test_env: Callable[[], None], -) -> None: - """This test simulates two clean runs using memory storage. - The second run attempts to access data created by the first one. - We run 2 configurations with different `purge_on_start`.""" - # Configure purging env var - monkeypatch.setenv('CRAWLEE_PURGE_ON_START', f'{int(purge_on_start)}') - # Store old storage client so we have the object reference for comparison - old_client = service_locator.get_storage_client() - - old_default_kvs = await KeyValueStore.open() - old_non_default_kvs = await KeyValueStore.open(name='non-default') - # Create data in default and non-default key-value store - await old_default_kvs.set_value('test', 'default value') - await old_non_default_kvs.set_value('test', 'non-default value') - - # We simulate another clean run, we expect the memory storage to read from the local data directory - # Default storages are purged based on purge_on_start parameter. - prepare_test_env() - - # Check if we're using a different memory storage instance - assert old_client is not service_locator.get_storage_client() - default_kvs = await KeyValueStore.open() - assert default_kvs is not old_default_kvs - non_default_kvs = await KeyValueStore.open(name='non-default') - assert non_default_kvs is not old_non_default_kvs - default_value = await default_kvs.get_value('test') - - if purge_on_start: - assert default_value is None - else: - assert default_value == 'default value' - - assert await non_default_kvs.get_value('test') == 'non-default value' - - -@pytest.mark.parametrize('purge_on_start', [True, False]) -async def test_actor_memory_storage_client_request_queue_e2e( - monkeypatch: pytest.MonkeyPatch, - purge_on_start: bool, # noqa: FBT001 - prepare_test_env: Callable[[], None], -) -> None: - """This test simulates two clean runs using memory storage. - The second run attempts to access data created by the first one. 
- We run 2 configurations with different `purge_on_start`.""" - # Configure purging env var - monkeypatch.setenv('CRAWLEE_PURGE_ON_START', f'{int(purge_on_start)}') - - # Add some requests to the default queue - default_queue = await RequestQueue.open() - for i in range(6): - # [0, 3] <- nothing special - # [1, 4] <- forefront=True - # [2, 5] <- handled=True - request_url = f'http://example.com/{i}' - forefront = i % 3 == 1 - was_handled = i % 3 == 2 - await default_queue.add_request( - Request.from_url( - unique_key=str(i), - url=request_url, - handled_at=datetime.now(timezone.utc) if was_handled else None, - payload=b'test', - ), - forefront=forefront, - ) - - # We simulate another clean run, we expect the memory storage to read from the local data directory - # Default storages are purged based on purge_on_start parameter. - prepare_test_env() - - # Add some more requests to the default queue - default_queue = await RequestQueue.open() - for i in range(6, 12): - # [6, 9] <- nothing special - # [7, 10] <- forefront=True - # [8, 11] <- handled=True - request_url = f'http://example.com/{i}' - forefront = i % 3 == 1 - was_handled = i % 3 == 2 - await default_queue.add_request( - Request.from_url( - unique_key=str(i), - url=request_url, - handled_at=datetime.now(timezone.utc) if was_handled else None, - payload=b'test', - ), - forefront=forefront, - ) - - queue_info = await default_queue.get_info() - assert queue_info is not None - - # If the queue was purged between the runs, only the requests from the second run should be present, - # in the right order - if purge_on_start: - assert queue_info.total_request_count == 6 - assert queue_info.handled_request_count == 2 - - expected_pending_request_order = [10, 7, 6, 9] - # If the queue was NOT purged between the runs, all the requests should be in the queue in the right order - else: - assert queue_info.total_request_count == 12 - assert queue_info.handled_request_count == 4 - - expected_pending_request_order = [10, 7, 4, 1, 0, 3, 6, 9] - - actual_requests = list[Request]() - while req := await default_queue.fetch_next_request(): - actual_requests.append(req) - - assert [int(req.unique_key) for req in actual_requests] == expected_pending_request_order - assert [req.url for req in actual_requests] == [f'http://example.com/{req.unique_key}' for req in actual_requests] - assert [req.payload for req in actual_requests] == [b'test' for _ in actual_requests] diff --git a/tests/unit/storage_clients/_memory/test_request_queue_client.py b/tests/unit/storage_clients/_memory/test_request_queue_client.py deleted file mode 100644 index feffacbbd8..0000000000 --- a/tests/unit/storage_clients/_memory/test_request_queue_client.py +++ /dev/null @@ -1,249 +0,0 @@ -from __future__ import annotations - -import asyncio -from datetime import datetime, timezone -from pathlib import Path -from typing import TYPE_CHECKING - -import pytest - -from crawlee import Request -from crawlee._request import RequestState - -if TYPE_CHECKING: - from crawlee.storage_clients import MemoryStorageClient - from crawlee.storage_clients._memory import RequestQueueClient - - -@pytest.fixture -async def request_queue_client(memory_storage_client: MemoryStorageClient) -> RequestQueueClient: - request_queues_client = memory_storage_client.request_queues() - rq_info = await request_queues_client.get_or_create(name='test') - return memory_storage_client.request_queue(rq_info.id) - - -async def test_nonexistent(memory_storage_client: MemoryStorageClient) -> None: - request_queue_client = 
memory_storage_client.request_queue(id='nonexistent-id') - assert await request_queue_client.get() is None - with pytest.raises(ValueError, match='Request queue with id "nonexistent-id" does not exist.'): - await request_queue_client.update(name='test-update') - await request_queue_client.delete() - - -async def test_get(request_queue_client: RequestQueueClient) -> None: - await asyncio.sleep(0.1) - info = await request_queue_client.get() - assert info is not None - assert info.id == request_queue_client.id - assert info.accessed_at != info.created_at - - -async def test_update(request_queue_client: RequestQueueClient) -> None: - new_rq_name = 'test-update' - request = Request.from_url('https://apify.com') - await request_queue_client.add_request(request) - old_rq_info = await request_queue_client.get() - assert old_rq_info is not None - assert old_rq_info.name is not None - old_rq_directory = Path( - request_queue_client._memory_storage_client.request_queues_directory, - old_rq_info.name, - ) - new_rq_directory = Path(request_queue_client._memory_storage_client.request_queues_directory, new_rq_name) - assert (old_rq_directory / 'fvwscO2UJLdr10B.json').exists() is True - assert (new_rq_directory / 'fvwscO2UJLdr10B.json').exists() is False - - await asyncio.sleep(0.1) - updated_rq_info = await request_queue_client.update(name=new_rq_name) - assert (old_rq_directory / 'fvwscO2UJLdr10B.json').exists() is False - assert (new_rq_directory / 'fvwscO2UJLdr10B.json').exists() is True - # Only modified_at and accessed_at should be different - assert old_rq_info.created_at == updated_rq_info.created_at - assert old_rq_info.modified_at != updated_rq_info.modified_at - assert old_rq_info.accessed_at != updated_rq_info.accessed_at - - # Should fail with the same name - with pytest.raises(ValueError, match='Request queue with name "test-update" already exists'): - await request_queue_client.update(name=new_rq_name) - - -async def test_delete(request_queue_client: RequestQueueClient) -> None: - await request_queue_client.add_request(Request.from_url('https://apify.com')) - rq_info = await request_queue_client.get() - assert rq_info is not None - - rq_directory = Path(request_queue_client._memory_storage_client.request_queues_directory, str(rq_info.name)) - assert (rq_directory / 'fvwscO2UJLdr10B.json').exists() is True - - await request_queue_client.delete() - assert (rq_directory / 'fvwscO2UJLdr10B.json').exists() is False - - # Does not crash when called again - await request_queue_client.delete() - - -async def test_list_head(request_queue_client: RequestQueueClient) -> None: - await request_queue_client.add_request(Request.from_url('https://apify.com')) - await request_queue_client.add_request(Request.from_url('https://example.com')) - list_head = await request_queue_client.list_head() - assert len(list_head.items) == 2 - - for item in list_head.items: - assert item.id is not None - - -async def test_request_state_serialization(request_queue_client: RequestQueueClient) -> None: - request = Request.from_url('https://crawlee.dev', payload=b'test') - request.state = RequestState.UNPROCESSED - - await request_queue_client.add_request(request) - - result = await request_queue_client.list_head() - assert len(result.items) == 1 - assert result.items[0] == request - - got_request = await request_queue_client.get_request(request.id) - - assert request == got_request - - -async def test_add_record(request_queue_client: RequestQueueClient) -> None: - processed_request_forefront = await 
request_queue_client.add_request( - Request.from_url('https://apify.com'), - forefront=True, - ) - processed_request_not_forefront = await request_queue_client.add_request( - Request.from_url('https://example.com'), - forefront=False, - ) - - assert processed_request_forefront.id is not None - assert processed_request_not_forefront.id is not None - assert processed_request_forefront.was_already_handled is False - assert processed_request_not_forefront.was_already_handled is False - - rq_info = await request_queue_client.get() - assert rq_info is not None - assert rq_info.pending_request_count == rq_info.total_request_count == 2 - assert rq_info.handled_request_count == 0 - - -async def test_get_record(request_queue_client: RequestQueueClient) -> None: - request_url = 'https://apify.com' - processed_request = await request_queue_client.add_request(Request.from_url(request_url)) - - request = await request_queue_client.get_request(processed_request.id) - assert request is not None - assert request.url == request_url - - # Non-existent id - assert (await request_queue_client.get_request('non-existent id')) is None - - -async def test_update_record(request_queue_client: RequestQueueClient) -> None: - processed_request = await request_queue_client.add_request(Request.from_url('https://apify.com')) - request = await request_queue_client.get_request(processed_request.id) - assert request is not None - - rq_info_before_update = await request_queue_client.get() - assert rq_info_before_update is not None - assert rq_info_before_update.pending_request_count == 1 - assert rq_info_before_update.handled_request_count == 0 - - request.handled_at = datetime.now(timezone.utc) - request_update_info = await request_queue_client.update_request(request) - - assert request_update_info.was_already_handled is False - - rq_info_after_update = await request_queue_client.get() - assert rq_info_after_update is not None - assert rq_info_after_update.pending_request_count == 0 - assert rq_info_after_update.handled_request_count == 1 - - -async def test_delete_record(request_queue_client: RequestQueueClient) -> None: - processed_request_pending = await request_queue_client.add_request( - Request.from_url( - url='https://apify.com', - unique_key='pending', - ), - ) - - processed_request_handled = await request_queue_client.add_request( - Request.from_url( - url='https://apify.com', - unique_key='handled', - handled_at=datetime.now(timezone.utc), - ), - ) - - rq_info_before_delete = await request_queue_client.get() - assert rq_info_before_delete is not None - assert rq_info_before_delete.pending_request_count == 1 - - await request_queue_client.delete_request(processed_request_pending.id) - rq_info_after_first_delete = await request_queue_client.get() - assert rq_info_after_first_delete is not None - assert rq_info_after_first_delete.pending_request_count == 0 - assert rq_info_after_first_delete.handled_request_count == 1 - - await request_queue_client.delete_request(processed_request_handled.id) - rq_info_after_second_delete = await request_queue_client.get() - assert rq_info_after_second_delete is not None - assert rq_info_after_second_delete.pending_request_count == 0 - assert rq_info_after_second_delete.handled_request_count == 0 - - # Does not crash when called again - await request_queue_client.delete_request(processed_request_pending.id) - - -async def test_forefront(request_queue_client: RequestQueueClient) -> None: - # this should create a queue with requests in this order: - # Handled: - # 2, 5, 8 - # Not 
handled: - # 7, 4, 1, 0, 3, 6 - for i in range(9): - request_url = f'http://example.com/{i}' - forefront = i % 3 == 1 - was_handled = i % 3 == 2 - await request_queue_client.add_request( - Request.from_url( - url=request_url, - unique_key=str(i), - handled_at=datetime.now(timezone.utc) if was_handled else None, - ), - forefront=forefront, - ) - - # Check that the queue head (unhandled items) is in the right order - queue_head = await request_queue_client.list_head() - req_unique_keys = [req.unique_key for req in queue_head.items] - assert req_unique_keys == ['7', '4', '1', '0', '3', '6'] - - # Mark request #1 as handled - await request_queue_client.update_request( - Request.from_url( - url='http://example.com/1', - unique_key='1', - handled_at=datetime.now(timezone.utc), - ), - ) - # Move request #3 to forefront - await request_queue_client.update_request( - Request.from_url(url='http://example.com/3', unique_key='3'), - forefront=True, - ) - - # Check that the queue head (unhandled items) is in the right order after the updates - queue_head = await request_queue_client.list_head() - req_unique_keys = [req.unique_key for req in queue_head.items] - assert req_unique_keys == ['3', '7', '4', '0', '6'] - - -async def test_add_duplicate_record(request_queue_client: RequestQueueClient) -> None: - processed_request = await request_queue_client.add_request(Request.from_url('https://apify.com')) - processed_request_duplicate = await request_queue_client.add_request(Request.from_url('https://apify.com')) - - assert processed_request.id == processed_request_duplicate.id - assert processed_request_duplicate.was_already_present is True diff --git a/tests/unit/storage_clients/_memory/test_request_queue_collection_client.py b/tests/unit/storage_clients/_memory/test_request_queue_collection_client.py deleted file mode 100644 index fa10889f83..0000000000 --- a/tests/unit/storage_clients/_memory/test_request_queue_collection_client.py +++ /dev/null @@ -1,42 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING - -import pytest - -if TYPE_CHECKING: - from crawlee.storage_clients import MemoryStorageClient - from crawlee.storage_clients._memory import RequestQueueCollectionClient - - -@pytest.fixture -def request_queues_client(memory_storage_client: MemoryStorageClient) -> RequestQueueCollectionClient: - return memory_storage_client.request_queues() - - -async def test_get_or_create(request_queues_client: RequestQueueCollectionClient) -> None: - rq_name = 'test' - # A new request queue gets created - rq_info = await request_queues_client.get_or_create(name=rq_name) - assert rq_info.name == rq_name - - # Another get_or_create call returns the same request queue - rq_existing = await request_queues_client.get_or_create(name=rq_name) - assert rq_info.id == rq_existing.id - assert rq_info.name == rq_existing.name - assert rq_info.created_at == rq_existing.created_at - - -async def test_list(request_queues_client: RequestQueueCollectionClient) -> None: - assert (await request_queues_client.list()).count == 0 - rq_info = await request_queues_client.get_or_create(name='dataset') - rq_list = await request_queues_client.list() - assert rq_list.count == 1 - assert rq_list.items[0].name == rq_info.name - - # Test sorting behavior - newer_rq_info = await request_queues_client.get_or_create(name='newer-dataset') - rq_list_sorting = await request_queues_client.list() - assert rq_list_sorting.count == 2 - assert rq_list_sorting.items[0].name == rq_info.name - assert rq_list_sorting.items[1].name == 
newer_rq_info.name diff --git a/tests/unit/storages/test_dataset.py b/tests/unit/storages/test_dataset.py index f299aee08d..c12a68d3e9 100644 --- a/tests/unit/storages/test_dataset.py +++ b/tests/unit/storages/test_dataset.py @@ -1,156 +1,456 @@ +# TODO: Update crawlee_storage_dir args once the Pydantic bug is fixed +# https://github.com/apify/crawlee-python/issues/146 + from __future__ import annotations -from datetime import datetime, timezone from typing import TYPE_CHECKING import pytest -from crawlee import service_locator -from crawlee.storage_clients.models import StorageMetadata +from crawlee.configuration import Configuration +from crawlee.storage_clients import FileSystemStorageClient, MemoryStorageClient from crawlee.storages import Dataset, KeyValueStore if TYPE_CHECKING: from collections.abc import AsyncGenerator + from pathlib import Path + + from crawlee.storage_clients import StorageClient + + + + +@pytest.fixture(params=['memory', 'file_system']) +def storage_client(request: pytest.FixtureRequest) -> StorageClient: + """Parameterized fixture to test with different storage clients.""" + if request.param == 'memory': + return MemoryStorageClient() + + return FileSystemStorageClient() @pytest.fixture -async def dataset() -> AsyncGenerator[Dataset, None]: - dataset = await Dataset.open() +def configuration(tmp_path: Path) -> Configuration: + """Provide a configuration with a temporary storage directory.""" + return Configuration(crawlee_storage_dir=str(tmp_path)) # type: ignore[call-arg] + + +@pytest.fixture +async def dataset( + storage_client: StorageClient, + configuration: Configuration, +) -> AsyncGenerator[Dataset, None]: + """Fixture that provides a dataset instance for each test.""" + Dataset._cache_by_id.clear() + Dataset._cache_by_name.clear() + + dataset = await Dataset.open( + name='test_dataset', + storage_client=storage_client, + configuration=configuration, + ) + yield dataset await dataset.drop() -async def test_open() -> None: - default_dataset = await Dataset.open() - default_dataset_by_id = await Dataset.open(id=default_dataset.id) +async def test_open_creates_new_dataset( + storage_client: StorageClient, + configuration: Configuration, +) -> None: + """Test that open() creates a new dataset with proper metadata.""" + dataset = await Dataset.open( + name='new_dataset', + storage_client=storage_client, + configuration=configuration, + ) + + # Verify dataset properties + assert dataset.id is not None + assert dataset.name == 'new_dataset' + assert dataset.metadata.item_count == 0 + + await dataset.drop() + + +async def test_open_existing_dataset( + dataset: Dataset, + storage_client: StorageClient, +) -> None: + """Test that open() loads an existing dataset correctly.""" + # Open the same dataset again + reopened_dataset = await Dataset.open( + name=dataset.name, + storage_client=storage_client, + ) - assert default_dataset is default_dataset_by_id + # Verify dataset properties + assert dataset.id == reopened_dataset.id + assert dataset.name == reopened_dataset.name + assert dataset.metadata.item_count == reopened_dataset.metadata.item_count + + # Verify they are the same object (from cache) + assert id(dataset) == id(reopened_dataset) + + +async def test_open_with_id_and_name( + storage_client: StorageClient, + configuration: Configuration, +) -> None: + """Test that open() raises an error when both id and name are provided.""" + with pytest.raises(ValueError, match='Only one of "id" or "name" can be specified'): + await Dataset.open( + id='some-id', + 
name='some-name', + storage_client=storage_client, + configuration=configuration, + ) + + +async def test_push_data_single_item(dataset: Dataset) -> None: + """Test pushing a single item to the dataset.""" + item = {'key': 'value', 'number': 42} + await dataset.push_data(item) + + # Verify item was stored + result = await dataset.get_data() + assert result.count == 1 + assert result.items[0] == item + + +async def test_push_data_multiple_items(dataset: Dataset) -> None: + """Test pushing multiple items to the dataset.""" + items = [ + {'id': 1, 'name': 'Item 1'}, + {'id': 2, 'name': 'Item 2'}, + {'id': 3, 'name': 'Item 3'}, + ] + await dataset.push_data(items) + + # Verify items were stored + result = await dataset.get_data() + assert result.count == 3 + assert result.items == items + + +async def test_get_data_empty_dataset(dataset: Dataset) -> None: + """Test getting data from an empty dataset returns empty results.""" + result = await dataset.get_data() + + assert result.count == 0 + assert result.total == 0 + assert result.items == [] + + +async def test_get_data_with_pagination(dataset: Dataset) -> None: + """Test getting data with offset and limit parameters for pagination.""" + # Add some items + items = [{'id': i} for i in range(1, 11)] # 10 items + await dataset.push_data(items) + + # Test offset + result = await dataset.get_data(offset=3) + assert result.count == 7 + assert result.offset == 3 + assert result.items[0]['id'] == 4 + + # Test limit + result = await dataset.get_data(limit=5) + assert result.count == 5 + assert result.limit == 5 + assert result.items[-1]['id'] == 5 + + # Test both offset and limit + result = await dataset.get_data(offset=2, limit=3) + assert result.count == 3 + assert result.offset == 2 + assert result.limit == 3 + assert result.items[0]['id'] == 3 + assert result.items[-1]['id'] == 5 + + +async def test_get_data_descending_order(dataset: Dataset) -> None: + """Test getting data in descending order reverses the item order.""" + # Add some items + items = [{'id': i} for i in range(1, 6)] # 5 items + await dataset.push_data(items) + + # Get items in descending order + result = await dataset.get_data(desc=True) + + assert result.desc is True + assert result.items[0]['id'] == 5 + assert result.items[-1]['id'] == 1 + + +async def test_get_data_skip_empty(dataset: Dataset) -> None: + """Test getting data with skip_empty option filters out empty items.""" + # Add some items including an empty one + items = [ + {'id': 1, 'name': 'Item 1'}, + {}, # Empty item + {'id': 3, 'name': 'Item 3'}, + ] + await dataset.push_data(items) + + # Get all items + result = await dataset.get_data() + assert result.count == 3 + + # Get non-empty items + result = await dataset.get_data(skip_empty=True) + assert result.count == 2 + assert all(item != {} for item in result.items) - dataset_name = 'dummy-name' - named_dataset = await Dataset.open(name=dataset_name) - assert default_dataset is not named_dataset - with pytest.raises(RuntimeError, match='Dataset with id "nonexistent-id" does not exist!'): - await Dataset.open(id='nonexistent-id') +async def test_iterate_items(dataset: Dataset) -> None: + """Test iterating over dataset items yields each item in the correct order.""" + # Add some items + items = [{'id': i} for i in range(1, 6)] # 5 items + await dataset.push_data(items) - # Test that when you try to open a dataset by ID and you use a name of an existing dataset, - # it doesn't work - with pytest.raises(RuntimeError, match='Dataset with id "dummy-name" does not exist!'): - 
await Dataset.open(id='dummy-name') + # Iterate over all items + collected_items = [item async for item in dataset.iterate_items()] + assert len(collected_items) == 5 + assert collected_items[0]['id'] == 1 + assert collected_items[-1]['id'] == 5 -async def test_consistency_accross_two_clients() -> None: - dataset = await Dataset.open(name='my-dataset') - await dataset.push_data({'key': 'value'}) - dataset_by_id = await Dataset.open(id=dataset.id) - await dataset_by_id.push_data({'key2': 'value2'}) +async def test_iterate_items_with_options(dataset: Dataset) -> None: + """Test iterating with offset, limit and desc parameters.""" + # Add some items + items = [{'id': i} for i in range(1, 11)] # 10 items + await dataset.push_data(items) + + # Test with offset and limit + collected_items = [item async for item in dataset.iterate_items(offset=3, limit=3)] + + assert len(collected_items) == 3 + assert collected_items[0]['id'] == 4 + assert collected_items[-1]['id'] == 6 + + # Test with descending order + collected_items = [] + async for item in dataset.iterate_items(desc=True, limit=3): + collected_items.append(item) + + assert len(collected_items) == 3 + assert collected_items[0]['id'] == 10 + assert collected_items[-1]['id'] == 8 + + +async def test_list_items(dataset: Dataset) -> None: + """Test that list_items returns all dataset items as a list.""" + # Add some items + items = [{'id': i} for i in range(1, 6)] # 5 items + await dataset.push_data(items) + + # Get all items as a list + collected_items = await dataset.list_items() + + assert len(collected_items) == 5 + assert collected_items[0]['id'] == 1 + assert collected_items[-1]['id'] == 5 + + +async def test_list_items_with_options(dataset: Dataset) -> None: + """Test that list_items respects filtering options.""" + # Add some items + items = [ + {'id': 1, 'name': 'Item 1'}, + {'id': 2, 'name': 'Item 2'}, + {'id': 3}, # Item with missing 'name' field + {}, # Empty item + {'id': 5, 'name': 'Item 5'}, + ] + await dataset.push_data(items) + + # Test with offset and limit + collected_items = await dataset.list_items(offset=1, limit=2) + assert len(collected_items) == 2 + assert collected_items[0]['id'] == 2 + assert collected_items[1]['id'] == 3 + + # Test with descending order - skip empty items to avoid KeyError + collected_items = await dataset.list_items(desc=True, skip_empty=True) + + # Filter items that have an 'id' field + items_with_ids = [item for item in collected_items if 'id' in item] + id_values = [item['id'] for item in items_with_ids] + + # Verify the list is sorted in descending order + assert sorted(id_values, reverse=True) == id_values, f'IDs should be in descending order. 
Got {id_values}' + + # Verify key IDs are present and in the right order + if 5 in id_values and 3 in id_values: + assert id_values.index(5) < id_values.index(3), 'ID 5 should come before ID 3 in descending order' + + # Test with skip_empty + collected_items = await dataset.list_items(skip_empty=True) + assert len(collected_items) == 4 # Should skip the empty item + assert all(item != {} for item in collected_items) + + # Test with fields - manually filter since 'fields' parameter is not supported + # Get all items first + collected_items = await dataset.list_items() + assert len(collected_items) == 5 + + # Manually extract only the 'id' field from each item + filtered_items = [{key: item[key] for key in ['id'] if key in item} for item in collected_items] + + # Verify 'name' field is not present in any item + assert all('name' not in item for item in filtered_items) + + # Test clean functionality manually instead of using the clean parameter + # Get all items + collected_items = await dataset.list_items() + + # Manually filter out empty items as 'clean' would do + clean_items = [item for item in collected_items if item != {}] - assert (await dataset.get_data()).items == [{'key': 'value'}, {'key2': 'value2'}] - assert (await dataset_by_id.get_data()).items == [{'key': 'value'}, {'key2': 'value2'}] + assert len(clean_items) == 4 # Should have 4 non-empty items + assert all(item != {} for item in clean_items) + +async def test_drop( + storage_client: StorageClient, + configuration: Configuration, +) -> None: + """Test dropping a dataset removes it from cache and clears its data.""" + dataset = await Dataset.open( + name='drop_test', + storage_client=storage_client, + configuration=configuration, + ) + + # Add some data + await dataset.push_data({'test': 'data'}) + + # Verify dataset exists in cache + assert dataset.id in Dataset._cache_by_id + if dataset.name: + assert dataset.name in Dataset._cache_by_name + + # Drop the dataset await dataset.drop() - with pytest.raises(RuntimeError, match='Storage with provided ID was not found'): - await dataset_by_id.drop() - - -async def test_same_references() -> None: - dataset1 = await Dataset.open() - dataset2 = await Dataset.open() - assert dataset1 is dataset2 - - dataset_name = 'non-default' - dataset_named1 = await Dataset.open(name=dataset_name) - dataset_named2 = await Dataset.open(name=dataset_name) - assert dataset_named1 is dataset_named2 - - -async def test_drop() -> None: - dataset1 = await Dataset.open() - await dataset1.drop() - dataset2 = await Dataset.open() - assert dataset1 is not dataset2 - - -async def test_export(dataset: Dataset) -> None: - expected_csv = 'id,test\r\n0,test\r\n1,test\r\n2,test\r\n' - expected_json = [{'id': 0, 'test': 'test'}, {'id': 1, 'test': 'test'}, {'id': 2, 'test': 'test'}] - desired_item_count = 3 - await dataset.push_data([{'id': i, 'test': 'test'} for i in range(desired_item_count)]) - await dataset.export_to(key='dataset-csv', content_type='csv') - await dataset.export_to(key='dataset-json', content_type='json') - kvs = await KeyValueStore.open() - dataset_csv = await kvs.get_value(key='dataset-csv') - dataset_json = await kvs.get_value(key='dataset-json') - assert dataset_csv == expected_csv - assert dataset_json == expected_json - - -async def test_push_data(dataset: Dataset) -> None: - desired_item_count = 2000 - await dataset.push_data([{'id': i} for i in range(desired_item_count)]) - dataset_info = await dataset.get_info() - assert dataset_info is not None - assert dataset_info.item_count == 
desired_item_count - list_page = await dataset.get_data(limit=desired_item_count) - assert list_page.items[0]['id'] == 0 - assert list_page.items[-1]['id'] == desired_item_count - 1 - - -async def test_push_data_empty(dataset: Dataset) -> None: - await dataset.push_data([]) - dataset_info = await dataset.get_info() - assert dataset_info is not None - assert dataset_info.item_count == 0 - - -async def test_push_data_singular(dataset: Dataset) -> None: - await dataset.push_data({'id': 1}) - dataset_info = await dataset.get_info() - assert dataset_info is not None - assert dataset_info.item_count == 1 - list_page = await dataset.get_data() - assert list_page.items[0]['id'] == 1 - - -async def test_get_data(dataset: Dataset) -> None: # We don't test everything, that's done in memory storage tests - desired_item_count = 3 - await dataset.push_data([{'id': i} for i in range(desired_item_count)]) - list_page = await dataset.get_data() - assert list_page.count == desired_item_count - assert list_page.desc is False - assert list_page.offset == 0 - assert list_page.items[0]['id'] == 0 - assert list_page.items[-1]['id'] == desired_item_count - 1 + # Verify dataset was removed from cache + assert dataset.id not in Dataset._cache_by_id + if dataset.name: + assert dataset.name not in Dataset._cache_by_name -async def test_iterate_items(dataset: Dataset) -> None: - desired_item_count = 3 - idx = 0 - await dataset.push_data([{'id': i} for i in range(desired_item_count)]) + # Verify dataset is empty (by creating a new one with the same name) + new_dataset = await Dataset.open( + name='drop_test', + storage_client=storage_client, + configuration=configuration, + ) + + result = await new_dataset.get_data() + assert result.count == 0 + await new_dataset.drop() + + +async def test_export_to_json( + dataset: Dataset, + storage_client: StorageClient, +) -> None: + """Test exporting dataset to JSON format.""" + # Create a key-value store for export + kvs = await KeyValueStore.open( + name='export_kvs', + storage_client=storage_client, + ) + + # Add some items to the dataset + items = [ + {'id': 1, 'name': 'Item 1'}, + {'id': 2, 'name': 'Item 2'}, + {'id': 3, 'name': 'Item 3'}, + ] + await dataset.push_data(items) + + # Export to JSON + await dataset.export_to( + key='dataset_export.json', + content_type='json', + to_key_value_store_name='export_kvs', + ) + + # Retrieve the exported file + record = await kvs.get_value(key='dataset_export.json') + assert record is not None - async for item in dataset.iterate_items(): - assert item['id'] == idx - idx += 1 + # Verify content has all the items + assert '"id": 1' in record + assert '"id": 2' in record + assert '"id": 3' in record - assert idx == desired_item_count + await kvs.drop() -async def test_from_storage_object() -> None: - storage_client = service_locator.get_storage_client() +async def test_export_to_csv( + dataset: Dataset, + storage_client: StorageClient, +) -> None: + """Test exporting dataset to CSV format.""" + # Create a key-value store for export + kvs = await KeyValueStore.open( + name='export_kvs', + storage_client=storage_client, + ) - storage_object = StorageMetadata( - id='dummy-id', - name='dummy-name', - accessed_at=datetime.now(timezone.utc), - created_at=datetime.now(timezone.utc), - modified_at=datetime.now(timezone.utc), - extra_attribute='extra', + # Add some items to the dataset + items = [ + {'id': 1, 'name': 'Item 1'}, + {'id': 2, 'name': 'Item 2'}, + {'id': 3, 'name': 'Item 3'}, + ] + await dataset.push_data(items) + + # Export to CSV + 
await dataset.export_to( + key='dataset_export.csv', + content_type='csv', + to_key_value_store_name='export_kvs', ) - dataset = Dataset.from_storage_object(storage_client, storage_object) + # Retrieve the exported file + record = await kvs.get_value(key='dataset_export.csv') + assert record is not None + + # Verify content has all the items + assert 'id,name' in record + assert '1,Item 1' in record + assert '2,Item 2' in record + assert '3,Item 3' in record + + await kvs.drop() + + +async def test_export_to_invalid_content_type(dataset: Dataset) -> None: + """Test exporting dataset with invalid content type raises error.""" + with pytest.raises(ValueError, match='Unsupported content type'): + await dataset.export_to( + key='invalid_export', + content_type='invalid', # type: ignore[call-overload] # Intentionally invalid content type + ) + + +async def test_large_dataset(dataset: Dataset) -> None: + """Test handling a large dataset with many items.""" + items = [{'id': i, 'value': f'value-{i}'} for i in range(100)] + await dataset.push_data(items) + + # Test that all items are retrieved + result = await dataset.get_data(limit=None) + assert result.count == 100 + assert result.total == 100 - assert dataset.id == storage_object.id - assert dataset.name == storage_object.name - assert dataset.storage_object == storage_object - assert storage_object.model_extra.get('extra_attribute') == 'extra' # type: ignore[union-attr] + # Test pagination with large datasets + result = await dataset.get_data(offset=50, limit=25) + assert result.count == 25 + assert result.offset == 50 + assert result.items[0]['id'] == 50 + assert result.items[-1]['id'] == 74 diff --git a/tests/unit/storages/test_key_value_store.py b/tests/unit/storages/test_key_value_store.py index 955d483546..4c43225d31 100644 --- a/tests/unit/storages/test_key_value_store.py +++ b/tests/unit/storages/test_key_value_store.py @@ -1,229 +1,328 @@ +# TODO: Update crawlee_storage_dir args once the Pydantic bug is fixed +# https://github.com/apify/crawlee-python/issues/146 + from __future__ import annotations -import asyncio -from datetime import datetime, timedelta, timezone -from itertools import chain, repeat -from typing import TYPE_CHECKING, cast -from unittest.mock import patch -from urllib.parse import urlparse +import json +from typing import TYPE_CHECKING import pytest -from crawlee import service_locator -from crawlee.events import EventManager -from crawlee.storage_clients.models import StorageMetadata +from crawlee.configuration import Configuration +from crawlee.storage_clients import FileSystemStorageClient, MemoryStorageClient from crawlee.storages import KeyValueStore if TYPE_CHECKING: from collections.abc import AsyncGenerator + from pathlib import Path - from crawlee._types import JsonSerializable - - -@pytest.fixture -async def mock_event_manager() -> AsyncGenerator[EventManager, None]: - async with EventManager(persist_state_interval=timedelta(milliseconds=50)) as event_manager: - with patch('crawlee.service_locator.get_event_manager', return_value=event_manager): - yield event_manager - - -async def test_open() -> None: - default_key_value_store = await KeyValueStore.open() - default_key_value_store_by_id = await KeyValueStore.open(id=default_key_value_store.id) - - assert default_key_value_store is default_key_value_store_by_id + from crawlee.storage_clients import StorageClient - key_value_store_name = 'dummy-name' - named_key_value_store = await KeyValueStore.open(name=key_value_store_name) - assert default_key_value_store 
is not named_key_value_store - with pytest.raises(RuntimeError, match='KeyValueStore with id "nonexistent-id" does not exist!'): - await KeyValueStore.open(id='nonexistent-id') - # Test that when you try to open a key-value store by ID and you use a name of an existing key-value store, - # it doesn't work - with pytest.raises(RuntimeError, match='KeyValueStore with id "dummy-name" does not exist!'): - await KeyValueStore.open(id='dummy-name') +@pytest.fixture(params=['memory', 'file_system']) +def storage_client(request: pytest.FixtureRequest) -> StorageClient: + """Parameterized fixture to test with different storage clients.""" + if request.param == 'memory': + return MemoryStorageClient() -async def test_open_save_storage_object() -> None: - default_key_value_store = await KeyValueStore.open() + return FileSystemStorageClient() - assert default_key_value_store.storage_object is not None - assert default_key_value_store.storage_object.id == default_key_value_store.id +@pytest.fixture +def configuration(tmp_path: Path) -> Configuration: + """Provide a configuration with a temporary storage directory.""" + return Configuration(crawlee_storage_dir=str(tmp_path)) # type: ignore[call-arg] -async def test_consistency_accross_two_clients() -> None: - kvs = await KeyValueStore.open(name='my-kvs') - await kvs.set_value('key', 'value') - - kvs_by_id = await KeyValueStore.open(id=kvs.id) - await kvs_by_id.set_value('key2', 'value2') - - assert (await kvs.get_value('key')) == 'value' - assert (await kvs.get_value('key2')) == 'value2' - assert (await kvs_by_id.get_value('key')) == 'value' - assert (await kvs_by_id.get_value('key2')) == 'value2' +@pytest.fixture +async def kvs( + storage_client: StorageClient, + configuration: Configuration, +) -> AsyncGenerator[KeyValueStore, None]: + """Fixture that provides a key-value store instance for each test.""" + KeyValueStore._cache_by_id.clear() + KeyValueStore._cache_by_name.clear() + + kvs = await KeyValueStore.open( + name='test_kvs', + storage_client=storage_client, + configuration=configuration, + ) + yield kvs await kvs.drop() - with pytest.raises(RuntimeError, match='Storage with provided ID was not found'): - await kvs_by_id.drop() - -async def test_same_references() -> None: - kvs1 = await KeyValueStore.open() - kvs2 = await KeyValueStore.open() - assert kvs1 is kvs2 - - kvs_name = 'non-default' - kvs_named1 = await KeyValueStore.open(name=kvs_name) - kvs_named2 = await KeyValueStore.open(name=kvs_name) - assert kvs_named1 is kvs_named2 - - -async def test_drop() -> None: - kvs1 = await KeyValueStore.open() - await kvs1.drop() - kvs2 = await KeyValueStore.open() - assert kvs1 is not kvs2 - - -async def test_get_set_value(key_value_store: KeyValueStore) -> None: - await key_value_store.set_value('test-str', 'string') - await key_value_store.set_value('test-int', 123) - await key_value_store.set_value('test-dict', {'abc': '123'}) - str_value = await key_value_store.get_value('test-str') - int_value = await key_value_store.get_value('test-int') - dict_value = await key_value_store.get_value('test-dict') - non_existent_value = await key_value_store.get_value('test-non-existent') - assert str_value == 'string' - assert int_value == 123 - assert dict_value['abc'] == '123' - assert non_existent_value is None +async def test_open_creates_new_kvs( + storage_client: StorageClient, + configuration: Configuration, +) -> None: + """Test that open() creates a new key-value store with proper metadata.""" + kvs = await KeyValueStore.open( + name='new_kvs', + 
storage_client=storage_client, + configuration=configuration, + ) -async def test_for_each_key(key_value_store: KeyValueStore) -> None: - keys = [item.key async for item in key_value_store.iterate_keys()] - assert len(keys) == 0 + # Verify key-value store properties + assert kvs.id is not None + assert kvs.name == 'new_kvs' - for i in range(2001): - await key_value_store.set_value(str(i).zfill(4), i) - index = 0 - async for item in key_value_store.iterate_keys(): - assert item.key == str(index).zfill(4) - index += 1 - assert index == 2001 + await kvs.drop() -async def test_static_get_set_value(key_value_store: KeyValueStore) -> None: - await key_value_store.set_value('test-static', 'static') - value = await key_value_store.get_value('test-static') - assert value == 'static' +async def test_open_existing_kvs( + kvs: KeyValueStore, + storage_client: StorageClient, +) -> None: + """Test that open() loads an existing key-value store correctly.""" + # Open the same key-value store again + reopened_kvs = await KeyValueStore.open( + name=kvs.name, + storage_client=storage_client, + ) + # Verify key-value store properties + assert kvs.id == reopened_kvs.id + assert kvs.name == reopened_kvs.name -async def test_get_public_url_raises_for_non_existing_key(key_value_store: KeyValueStore) -> None: - with pytest.raises(ValueError, match='was not found'): - await key_value_store.get_public_url('i-do-not-exist') + # Verify they are the same object (from cache) + assert id(kvs) == id(reopened_kvs) -async def test_get_public_url(key_value_store: KeyValueStore) -> None: - await key_value_store.set_value('test-static', 'static') - public_url = await key_value_store.get_public_url('test-static') +async def test_open_with_id_and_name( + storage_client: StorageClient, + configuration: Configuration, +) -> None: + """Test that open() raises an error when both id and name are provided.""" + with pytest.raises(ValueError, match='Only one of "id" or "name" can be specified'): + await KeyValueStore.open( + id='some-id', + name='some-name', + storage_client=storage_client, + configuration=configuration, + ) - url = urlparse(public_url) - path = url.netloc if url.netloc else url.path - with open(path) as f: # noqa: ASYNC230 - content = await asyncio.to_thread(f.read) - assert content == 'static' +async def test_set_get_value(kvs: KeyValueStore) -> None: + """Test setting and getting a value from the key-value store.""" + # Set a value + test_key = 'test-key' + test_value = {'data': 'value', 'number': 42} + await kvs.set_value(test_key, test_value) + # Get the value + result = await kvs.get_value(test_key) + assert result == test_value -async def test_get_auto_saved_value_default_value(key_value_store: KeyValueStore) -> None: - default_value: dict[str, JsonSerializable] = {'hello': 'world'} - value = await key_value_store.get_auto_saved_value('state', default_value) - assert value == default_value +async def test_get_value_nonexistent(kvs: KeyValueStore) -> None: + """Test getting a nonexistent value returns None.""" + result = await kvs.get_value('nonexistent-key') + assert result is None -async def test_get_auto_saved_value_cache_value(key_value_store: KeyValueStore) -> None: - default_value: dict[str, JsonSerializable] = {'hello': 'world'} - key_name = 'state' - value = await key_value_store.get_auto_saved_value(key_name, default_value) - value['hello'] = 'new_world' - value_one = await key_value_store.get_auto_saved_value(key_name) - assert value_one == {'hello': 'new_world'} +async def 
test_get_value_with_default(kvs: KeyValueStore) -> None: + """Test getting a nonexistent value with a default value.""" + default_value = {'default': True} + result = await kvs.get_value('nonexistent-key', default_value=default_value) + assert result == default_value - value_one['hello'] = ['new_world'] - value_two = await key_value_store.get_auto_saved_value(key_name) - assert value_two == {'hello': ['new_world']} +async def test_set_value_with_content_type(kvs: KeyValueStore) -> None: + """Test setting a value with a specific content type.""" + test_key = 'test-json' + test_value = {'data': 'value', 'items': [1, 2, 3]} + await kvs.set_value(test_key, test_value, content_type='application/json') -async def test_get_auto_saved_value_auto_save(key_value_store: KeyValueStore, mock_event_manager: EventManager) -> None: # noqa: ARG001 - # This is not a realtime system and timing constrains can be hard to enforce. - # For the test to avoid flakiness it needs some time tolerance. - autosave_deadline_time = 1 - autosave_check_period = 0.01 + # Verify the value is retrievable + result = await kvs.get_value(test_key) + assert result == test_value - async def autosaved_within_deadline(key: str, expected_value: dict[str, str]) -> bool: - """Check if the `key_value_store` of `key` has expected value within `autosave_deadline_time` seconds.""" - deadline = datetime.now(tz=timezone.utc) + timedelta(seconds=autosave_deadline_time) - while datetime.now(tz=timezone.utc) < deadline: - await asyncio.sleep(autosave_check_period) - if await key_value_store.get_value(key) == expected_value: - return True - return False - default_value: dict[str, JsonSerializable] = {'hello': 'world'} - key_name = 'state' - value = await key_value_store.get_auto_saved_value(key_name, default_value) - assert await autosaved_within_deadline(key=key_name, expected_value={'hello': 'world'}) +async def test_delete_value(kvs: KeyValueStore) -> None: + """Test deleting a value from the key-value store.""" + # Set a value first + test_key = 'delete-me' + test_value = 'value to delete' + await kvs.set_value(test_key, test_value) - value['hello'] = 'new_world' - assert await autosaved_within_deadline(key=key_name, expected_value={'hello': 'new_world'}) + # Verify value exists + assert await kvs.get_value(test_key) == test_value + # Delete the value + await kvs.delete_value(test_key) -async def test_get_auto_saved_value_auto_save_race_conditions(key_value_store: KeyValueStore) -> None: - """Two parallel functions increment global variable obtained by `get_auto_saved_value`. + # Verify value is gone + assert await kvs.get_value(test_key) is None - Result should be incremented by 2. - Method `get_auto_saved_value` must be implemented in a way that prevents race conditions in such scenario. - Test creates situation where first `get_auto_saved_value` call to kvs gets delayed. 
Such situation can happen - and unless handled, it can cause race condition in getting the state value.""" - await key_value_store.set_value('state', {'counter': 0}) - sleep_time_iterator = chain(iter([0.5]), repeat(0)) +async def test_list_keys_empty_kvs(kvs: KeyValueStore) -> None: + """Test listing keys from an empty key-value store.""" + keys = await kvs.list_keys() + assert len(keys) == 0 - async def delayed_get_value(key: str, default_value: None) -> None: - await asyncio.sleep(next(sleep_time_iterator)) - return await KeyValueStore.get_value(key_value_store, key=key, default_value=default_value) - async def increment_counter() -> None: - state = cast('dict[str, int]', await key_value_store.get_auto_saved_value('state')) - state['counter'] += 1 +async def test_list_keys(kvs: KeyValueStore) -> None: + """Test listing keys from a key-value store with items.""" + # Add some items + await kvs.set_value('key1', 'value1') + await kvs.set_value('key2', 'value2') + await kvs.set_value('key3', 'value3') + + # List keys + keys = await kvs.list_keys() + + # Verify keys + assert len(keys) == 3 + key_names = [k.key for k in keys] + assert 'key1' in key_names + assert 'key2' in key_names + assert 'key3' in key_names + + +async def test_list_keys_with_limit(kvs: KeyValueStore) -> None: + """Test listing keys with a limit parameter.""" + # Add some items + for i in range(10): + await kvs.set_value(f'key{i}', f'value{i}') + + # List with limit + keys = await kvs.list_keys(limit=5) + assert len(keys) == 5 + + +async def test_list_keys_with_exclusive_start_key(kvs: KeyValueStore) -> None: + """Test listing keys with an exclusive start key.""" + # Add some items in a known order + await kvs.set_value('key1', 'value1') + await kvs.set_value('key2', 'value2') + await kvs.set_value('key3', 'value3') + await kvs.set_value('key4', 'value4') + await kvs.set_value('key5', 'value5') + + # Get all keys first to determine their order + all_keys = await kvs.list_keys() + all_key_names = [k.key for k in all_keys] + + if len(all_key_names) >= 3: + # Start from the second key + start_key = all_key_names[1] + keys = await kvs.list_keys(exclusive_start_key=start_key) + + # We should get all keys after the start key + expected_count = len(all_key_names) - all_key_names.index(start_key) - 1 + assert len(keys) == expected_count + + # First key should be the one after start_key + first_returned_key = keys[0].key + assert first_returned_key != start_key + assert all_key_names.index(first_returned_key) > all_key_names.index(start_key) + + +async def test_iterate_keys(kvs: KeyValueStore) -> None: + """Test iterating over keys in the key-value store.""" + # Add some items + await kvs.set_value('key1', 'value1') + await kvs.set_value('key2', 'value2') + await kvs.set_value('key3', 'value3') + + collected_keys = [key async for key in kvs.iterate_keys()] + + # Verify iteration result + assert len(collected_keys) == 3 + key_names = [k.key for k in collected_keys] + assert 'key1' in key_names + assert 'key2' in key_names + assert 'key3' in key_names + + +async def test_iterate_keys_with_limit(kvs: KeyValueStore) -> None: + """Test iterating over keys with a limit parameter.""" + # Add some items + for i in range(10): + await kvs.set_value(f'key{i}', f'value{i}') + + collected_keys = [key async for key in kvs.iterate_keys(limit=5)] + + # Verify iteration result + assert len(collected_keys) == 5 + + +async def test_drop( + storage_client: StorageClient, + configuration: Configuration, +) -> None: + """Test dropping a key-value store 
removes it from cache and clears its data.""" + kvs = await KeyValueStore.open( + name='drop_test', + storage_client=storage_client, + configuration=configuration, + ) - with patch.object(key_value_store, 'get_value', delayed_get_value): - tasks = [asyncio.create_task(increment_counter()), asyncio.create_task(increment_counter())] - await asyncio.gather(*tasks) + # Add some data + await kvs.set_value('test', 'data') - assert (await key_value_store.get_auto_saved_value('state'))['counter'] == 2 + # Verify key-value store exists in cache + assert kvs.id in KeyValueStore._cache_by_id + if kvs.name: + assert kvs.name in KeyValueStore._cache_by_name + # Drop the key-value store + await kvs.drop() -async def test_from_storage_object() -> None: - storage_client = service_locator.get_storage_client() + # Verify key-value store was removed from cache + assert kvs.id not in KeyValueStore._cache_by_id + if kvs.name: + assert kvs.name not in KeyValueStore._cache_by_name - storage_object = StorageMetadata( - id='dummy-id', - name='dummy-name', - accessed_at=datetime.now(timezone.utc), - created_at=datetime.now(timezone.utc), - modified_at=datetime.now(timezone.utc), - extra_attribute='extra', + # Verify key-value store is empty (by creating a new one with the same name) + new_kvs = await KeyValueStore.open( + name='drop_test', + storage_client=storage_client, + configuration=configuration, ) - key_value_store = KeyValueStore.from_storage_object(storage_client, storage_object) - - assert key_value_store.id == storage_object.id - assert key_value_store.name == storage_object.name - assert key_value_store.storage_object == storage_object - assert storage_object.model_extra.get('extra_attribute') == 'extra' # type: ignore[union-attr] + # Attempt to get a previously stored value + result = await new_kvs.get_value('test') + assert result is None + await new_kvs.drop() + + +async def test_complex_data_types(kvs: KeyValueStore) -> None: + """Test storing and retrieving complex data types.""" + # Test nested dictionaries + nested_dict = { + 'level1': { + 'level2': { + 'level3': 'deep value', + 'numbers': [1, 2, 3], + }, + }, + 'array': [{'a': 1}, {'b': 2}], + } + await kvs.set_value('nested', nested_dict) + result = await kvs.get_value('nested') + assert result == nested_dict + + # Test lists + test_list = [1, 'string', True, None, {'key': 'value'}] + await kvs.set_value('list', test_list) + result = await kvs.get_value('list') + assert result == test_list + + +async def test_string_data(kvs: KeyValueStore) -> None: + """Test storing and retrieving string data.""" + # Plain string + await kvs.set_value('string', 'simple string') + result = await kvs.get_value('string') + assert result == 'simple string' + + # JSON string + json_string = json.dumps({'key': 'value'}) + await kvs.set_value('json_string', json_string) + result = await kvs.get_value('json_string') + assert result == json_string diff --git a/tests/unit/storages/test_request_manager_tandem.py b/tests/unit/storages/test_request_manager_tandem.py index d08ab57dc1..060484a136 100644 --- a/tests/unit/storages/test_request_manager_tandem.py +++ b/tests/unit/storages/test_request_manager_tandem.py @@ -54,7 +54,7 @@ async def test_basic_functionality(test_input: TestInput) -> None: request_queue = await RequestQueue.open() if test_input.request_manager_items: - await request_queue.add_requests_batched(test_input.request_manager_items) + await request_queue.add_requests(test_input.request_manager_items) mock_request_loader = create_autospec(RequestLoader, 
instance=True, spec_set=True) mock_request_loader.fetch_next_request.side_effect = lambda: test_input.request_loader_items.pop(0) diff --git a/tests/unit/storages/test_request_queue.py b/tests/unit/storages/test_request_queue.py index cddba8ef99..d9a0f98470 100644 --- a/tests/unit/storages/test_request_queue.py +++ b/tests/unit/storages/test_request_queue.py @@ -1,367 +1,500 @@ +# TODO: Update crawlee_storage_dir args once the Pydantic bug is fixed +# https://github.com/apify/crawlee-python/issues/146 + from __future__ import annotations import asyncio -from datetime import datetime, timedelta, timezone -from itertools import count from typing import TYPE_CHECKING -from unittest.mock import AsyncMock, MagicMock import pytest -from pydantic import ValidationError - -from crawlee import Request, service_locator -from crawlee._request import RequestState -from crawlee.storage_clients import MemoryStorageClient, StorageClient -from crawlee.storage_clients._memory import RequestQueueClient -from crawlee.storage_clients.models import ( - BatchRequestsOperationResponse, - StorageMetadata, - UnprocessedRequest, -) + +from crawlee import Request +from crawlee.configuration import Configuration +from crawlee.storage_clients import FileSystemStorageClient, MemoryStorageClient, StorageClient from crawlee.storages import RequestQueue if TYPE_CHECKING: - from collections.abc import AsyncGenerator, Sequence + from collections.abc import AsyncGenerator + from pathlib import Path + + + + + +@pytest.fixture(params=['memory', 'file_system']) +def storage_client(request: pytest.FixtureRequest) -> StorageClient: + """Parameterized fixture to test with different storage clients.""" + if request.param == 'memory': + return MemoryStorageClient() + + return FileSystemStorageClient() @pytest.fixture -async def request_queue() -> AsyncGenerator[RequestQueue, None]: - rq = await RequestQueue.open() +def configuration(tmp_path: Path) -> Configuration: + """Provide a configuration with a temporary storage directory.""" + return Configuration(crawlee_storage_dir=str(tmp_path)) # type: ignore[call-arg] + + +@pytest.fixture +async def rq( + storage_client: StorageClient, + configuration: Configuration, +) -> AsyncGenerator[RequestQueue, None]: + """Fixture that provides a request queue instance for each test.""" + RequestQueue._cache_by_id.clear() + RequestQueue._cache_by_name.clear() + + rq = await RequestQueue.open( + name='test_request_queue', + storage_client=storage_client, + configuration=configuration, + ) + yield rq await rq.drop() -async def test_open() -> None: - default_request_queue = await RequestQueue.open() - default_request_queue_by_id = await RequestQueue.open(id=default_request_queue.id) +async def test_open_creates_new_rq( + storage_client: StorageClient, + configuration: Configuration, +) -> None: + """Test that open() creates a new request queue with proper metadata.""" + rq = await RequestQueue.open( + name='new_request_queue', + storage_client=storage_client, + configuration=configuration, + ) - assert default_request_queue is default_request_queue_by_id + # Verify request queue properties + assert rq.id is not None + assert rq.name == 'new_request_queue' + assert rq.metadata.pending_request_count == 0 + assert rq.metadata.handled_request_count == 0 + assert rq.metadata.total_request_count == 0 - request_queue_name = 'dummy-name' - named_request_queue = await RequestQueue.open(name=request_queue_name) - assert default_request_queue is not named_request_queue + await rq.drop() - with 
pytest.raises(RuntimeError, match='RequestQueue with id "nonexistent-id" does not exist!'): - await RequestQueue.open(id='nonexistent-id') - # Test that when you try to open a request queue by ID and you use a name of an existing request queue, - # it doesn't work - with pytest.raises(RuntimeError, match='RequestQueue with id "dummy-name" does not exist!'): - await RequestQueue.open(id='dummy-name') +async def test_open_existing_rq( + rq: RequestQueue, + storage_client: StorageClient, +) -> None: + """Test that open() loads an existing request queue correctly.""" + # Open the same request queue again + reopened_rq = await RequestQueue.open( + name=rq.name, + storage_client=storage_client, + ) + # Verify request queue properties + assert rq.id == reopened_rq.id + assert rq.name == reopened_rq.name -async def test_consistency_accross_two_clients() -> None: - request_apify = Request.from_url('https://apify.com') - request_crawlee = Request.from_url('https://crawlee.dev') + # Verify they are the same object (from cache) + assert id(rq) == id(reopened_rq) - rq = await RequestQueue.open(name='my-rq') - await rq.add_request(request_apify) - rq_by_id = await RequestQueue.open(id=rq.id) - await rq_by_id.add_request(request_crawlee) +async def test_open_with_id_and_name( + storage_client: StorageClient, + configuration: Configuration, +) -> None: + """Test that open() raises an error when both id and name are provided.""" + with pytest.raises(ValueError, match='Only one of "id" or "name" can be specified'): + await RequestQueue.open( + id='some-id', + name='some-name', + storage_client=storage_client, + configuration=configuration, + ) - assert await rq.get_total_count() == 2 - assert await rq_by_id.get_total_count() == 2 - assert await rq.fetch_next_request() == request_apify - assert await rq_by_id.fetch_next_request() == request_crawlee +async def test_add_request_string_url(rq: RequestQueue) -> None: + """Test adding a request with a string URL.""" + # Add a request with a string URL + url = 'https://example.com' + result = await rq.add_request(url) - await rq.drop() - with pytest.raises(RuntimeError, match='Storage with provided ID was not found'): - await rq_by_id.drop() + # Verify request was added + assert result.id is not None + assert result.unique_key is not None + assert result.was_already_present is False + assert result.was_already_handled is False + # Verify the queue stats were updated + assert rq.metadata.total_request_count == 1 + assert rq.metadata.pending_request_count == 1 -async def test_same_references() -> None: - rq1 = await RequestQueue.open() - rq2 = await RequestQueue.open() - assert rq1 is rq2 - rq_name = 'non-default' - rq_named1 = await RequestQueue.open(name=rq_name) - rq_named2 = await RequestQueue.open(name=rq_name) - assert rq_named1 is rq_named2 +async def test_add_request_object(rq: RequestQueue) -> None: + """Test adding a request object.""" + # Create and add a request object + request = Request.from_url(url='https://example.com', user_data={'key': 'value'}) + result = await rq.add_request(request) + # Verify request was added + assert result.id is not None + assert result.unique_key is not None + assert result.was_already_present is False + assert result.was_already_handled is False -async def test_drop() -> None: - rq1 = await RequestQueue.open() - await rq1.drop() - rq2 = await RequestQueue.open() - assert rq1 is not rq2 + # Verify the queue stats were updated + assert rq.metadata.total_request_count == 1 + assert rq.metadata.pending_request_count == 1 
-async def test_get_request(request_queue: RequestQueue) -> None: - request = Request.from_url('https://example.com') - processed_request = await request_queue.add_request(request) - assert request.id == processed_request.id - request_2 = await request_queue.get_request(request.id) - assert request_2 is not None - assert request == request_2 +async def test_add_duplicate_request(rq: RequestQueue) -> None: + """Test adding a duplicate request to the queue.""" + # Add a request + url = 'https://example.com' + first_result = await rq.add_request(url) + # Add the same request again + second_result = await rq.add_request(url) -async def test_add_fetch_handle_request(request_queue: RequestQueue) -> None: - request = Request.from_url('https://example.com') - assert await request_queue.is_empty() is True - add_request_info = await request_queue.add_request(request) + # Verify the second request was detected as duplicate + assert second_result.was_already_present is True + assert second_result.unique_key == first_result.unique_key - assert add_request_info.was_already_present is False - assert add_request_info.was_already_handled is False - assert await request_queue.is_empty() is False + # Verify the queue stats weren't incremented twice + assert rq.metadata.total_request_count == 1 + assert rq.metadata.pending_request_count == 1 - # Fetch the request - next_request = await request_queue.fetch_next_request() - assert next_request is not None - # Mark it as handled - next_request.handled_at = datetime.now(timezone.utc) - processed_request = await request_queue.mark_request_as_handled(next_request) +async def test_add_requests_batch(rq: RequestQueue) -> None: + """Test adding multiple requests in a batch.""" + # Create a batch of requests + urls = [ + 'https://example.com/page1', + 'https://example.com/page2', + 'https://example.com/page3', + ] - assert processed_request is not None - assert processed_request.id == request.id - assert processed_request.unique_key == request.unique_key - assert await request_queue.is_finished() is True + # Add the requests + await rq.add_requests(urls) + # Wait for all background tasks to complete + await asyncio.sleep(0.1) -async def test_reclaim_request(request_queue: RequestQueue) -> None: - request = Request.from_url('https://example.com') - await request_queue.add_request(request) + # Verify the queue stats + assert rq.metadata.total_request_count == 3 + assert rq.metadata.pending_request_count == 3 - # Fetch the request - next_request = await request_queue.fetch_next_request() - assert next_request is not None - assert next_request.unique_key == request.url - - # Reclaim - await request_queue.reclaim_request(next_request) - # Try to fetch again after a few secs - await asyncio.sleep(4) # 3 seconds is the consistency delay in request queue - next_again = await request_queue.fetch_next_request() - - assert next_again is not None - assert next_again.id == request.id - assert next_again.unique_key == request.unique_key - - -@pytest.mark.parametrize( - 'requests', - [ - [Request.from_url('https://apify.com')], - ['https://crawlee.dev'], - [Request.from_url(f'https://example.com/{i}') for i in range(10)], - [f'https://example.com/{i}' for i in range(15)], - ], - ids=['single-request', 'single-url', 'multiple-requests', 'multiple-urls'], -) -async def test_add_batched_requests( - request_queue: RequestQueue, - requests: Sequence[str | Request], -) -> None: - request_count = len(requests) - # Add the requests to the RQ in batches - await 
request_queue.add_requests_batched(requests, wait_for_all_requests_to_be_added=True) +async def test_add_requests_batch_with_forefront(rq: RequestQueue) -> None: + """Test adding multiple requests in a batch with forefront option.""" + # Add some initial requests + await rq.add_request('https://example.com/page1') + await rq.add_request('https://example.com/page2') + + # Add a batch of priority requests at the forefront - # Ensure the batch was processed correctly - assert await request_queue.get_total_count() == request_count + await rq.add_requests( + [ + 'https://example.com/priority1', + 'https://example.com/priority2', + 'https://example.com/priority3', + ], + forefront=True, + ) - # Fetch and validate each request in the queue - for original_request in requests: - next_request = await request_queue.fetch_next_request() - assert next_request is not None + # Wait for all background tasks to complete + await asyncio.sleep(0.1) - expected_url = original_request if isinstance(original_request, str) else original_request.url - assert next_request.url == expected_url + # Fetch requests - they should come out in priority order first + next_request1 = await rq.fetch_next_request() + assert next_request1 is not None + assert next_request1.url.startswith('https://example.com/priority') - # Confirm the queue is empty after processing all requests - assert await request_queue.is_empty() is True + next_request2 = await rq.fetch_next_request() + assert next_request2 is not None + assert next_request2.url.startswith('https://example.com/priority') + next_request3 = await rq.fetch_next_request() + assert next_request3 is not None + assert next_request3.url.startswith('https://example.com/priority') -async def test_invalid_user_data_serialization() -> None: - with pytest.raises(ValidationError): - Request.from_url( - 'https://crawlee.dev', - user_data={ - 'foo': datetime(year=2020, month=7, day=4, tzinfo=timezone.utc), - 'bar': {datetime(year=2020, month=4, day=7, tzinfo=timezone.utc)}, - }, - ) + # Now we should get the original requests + next_request4 = await rq.fetch_next_request() + assert next_request4 is not None + assert next_request4.url == 'https://example.com/page1' + next_request5 = await rq.fetch_next_request() + assert next_request5 is not None + assert next_request5.url == 'https://example.com/page2' -async def test_user_data_serialization(request_queue: RequestQueue) -> None: - request = Request.from_url( - 'https://crawlee.dev', - user_data={ - 'hello': 'world', - 'foo': 42, - }, + # Queue should be empty now + next_request6 = await rq.fetch_next_request() + assert next_request6 is None + + +async def test_add_requests_mixed_forefront(rq: RequestQueue) -> None: + """Test the ordering when adding requests with mixed forefront values.""" + # Add normal requests + await rq.add_request('https://example.com/normal1') + await rq.add_request('https://example.com/normal2') + + # Add a batch with forefront=True + await rq.add_requests( + ['https://example.com/priority1', 'https://example.com/priority2'], + forefront=True, ) - await request_queue.add_request(request) + # Add another normal request + await rq.add_request('https://example.com/normal3') + + # Add another priority request + await rq.add_request('https://example.com/priority3', forefront=True) - dequeued_request = await request_queue.fetch_next_request() - assert dequeued_request is not None + # Wait for background tasks + await asyncio.sleep(0.1) - assert dequeued_request.user_data['hello'] == 'world' - assert 
dequeued_request.user_data['foo'] == 42 + # The expected order should be: + # 1. priority3 (most recent forefront) + # 2. priority1 (from batch, forefront) + # 3. priority2 (from batch, forefront) + # 4. normal1 (oldest normal) + # 5. normal2 + # 6. normal3 (newest normal) + requests = [] + while True: + req = await rq.fetch_next_request() + if req is None: + break + requests.append(req) + await rq.mark_request_as_handled(req) -async def test_complex_user_data_serialization(request_queue: RequestQueue) -> None: - request = Request.from_url('https://crawlee.dev') - request.user_data['hello'] = 'world' - request.user_data['foo'] = 42 - request.crawlee_data.max_retries = 1 - request.crawlee_data.state = RequestState.ERROR_HANDLER + assert len(requests) == 6 + assert requests[0].url == 'https://example.com/priority3' - await request_queue.add_request(request) + # The next two should be from the forefront batch (exact order within batch may vary) + batch_urls = {requests[1].url, requests[2].url} + assert 'https://example.com/priority1' in batch_urls + assert 'https://example.com/priority2' in batch_urls - dequeued_request = await request_queue.fetch_next_request() - assert dequeued_request is not None + # Then the normal requests in order + assert requests[3].url == 'https://example.com/normal1' + assert requests[4].url == 'https://example.com/normal2' + assert requests[5].url == 'https://example.com/normal3' - data = dequeued_request.model_dump(by_alias=True) - assert data['userData']['hello'] == 'world' - assert data['userData']['foo'] == 42 - assert data['userData']['__crawlee'] == { - 'maxRetries': 1, - 'state': RequestState.ERROR_HANDLER, - } +async def test_add_requests_with_forefront(rq: RequestQueue) -> None: + """Test adding requests to the front of the queue.""" + # Add some initial requests + await rq.add_request('https://example.com/page1') + await rq.add_request('https://example.com/page2') -async def test_deduplication_of_requests_with_custom_unique_key() -> None: - with pytest.raises(ValueError, match='`always_enqueue` cannot be used with a custom `unique_key`'): - Request.from_url('https://apify.com', unique_key='apify', always_enqueue=True) + # Add a priority request at the forefront + await rq.add_request('https://example.com/priority', forefront=True) + # Fetch the next request - should be the priority one + next_request = await rq.fetch_next_request() + assert next_request is not None + assert next_request.url == 'https://example.com/priority' -async def test_deduplication_of_requests_with_invalid_custom_unique_key() -> None: - request_1 = Request.from_url('https://apify.com', always_enqueue=True) - request_2 = Request.from_url('https://apify.com', always_enqueue=True) - rq = await RequestQueue.open(name='my-rq') - await rq.add_request(request_1) - await rq.add_request(request_2) +async def test_fetch_next_request_and_mark_handled(rq: RequestQueue) -> None: + """Test fetching and marking requests as handled.""" + # Add some requests + await rq.add_request('https://example.com/page1') + await rq.add_request('https://example.com/page2') - assert await rq.get_total_count() == 2 + # Fetch first request + request1 = await rq.fetch_next_request() + assert request1 is not None + assert request1.url == 'https://example.com/page1' - assert await rq.fetch_next_request() == request_1 - assert await rq.fetch_next_request() == request_2 + # Mark the request as handled + result = await rq.mark_request_as_handled(request1) + assert result is not None + assert result.was_already_handled is 
True + # Fetch next request + request2 = await rq.fetch_next_request() + assert request2 is not None + assert request2.url == 'https://example.com/page2' -async def test_deduplication_of_requests_with_valid_custom_unique_key() -> None: - request_1 = Request.from_url('https://apify.com') - request_2 = Request.from_url('https://apify.com') + # Mark the second request as handled + await rq.mark_request_as_handled(request2) - rq = await RequestQueue.open(name='my-rq') - await rq.add_request(request_1) - await rq.add_request(request_2) + # Verify counts + assert rq.metadata.total_request_count == 2 + assert rq.metadata.handled_request_count == 2 + assert rq.metadata.pending_request_count == 0 - assert await rq.get_total_count() == 1 + # Verify queue is empty + empty_request = await rq.fetch_next_request() + assert empty_request is None - assert await rq.fetch_next_request() == request_1 +async def test_get_request_by_id(rq: RequestQueue) -> None: + """Test retrieving a request by its ID.""" + # Add a request + added_result = await rq.add_request('https://example.com') + request_id = added_result.id -async def test_cache_requests(request_queue: RequestQueue) -> None: - request_1 = Request.from_url('https://apify.com') - request_2 = Request.from_url('https://crawlee.dev') + # Retrieve the request by ID + retrieved_request = await rq.get_request(request_id) + assert retrieved_request is not None + assert retrieved_request.id == request_id + assert retrieved_request.url == 'https://example.com' - await request_queue.add_request(request_1) - await request_queue.add_request(request_2) - assert request_queue._requests_cache.currsize == 2 +async def test_get_non_existent_request(rq: RequestQueue) -> None: + """Test retrieving a request that doesn't exist.""" + non_existent_request = await rq.get_request('non-existent-id') + assert non_existent_request is None - fetched_request = await request_queue.fetch_next_request() - assert fetched_request is not None - assert fetched_request.id == request_1.id +async def test_reclaim_request(rq: RequestQueue) -> None: + """Test reclaiming a request that failed processing.""" + # Add a request + await rq.add_request('https://example.com') - # After calling fetch_next_request request_1 moved to the end of the cache store. 
- cached_items = [request_queue._requests_cache.popitem()[0] for _ in range(2)] - assert cached_items == [request_2.id, request_1.id] + # Fetch the request + request = await rq.fetch_next_request() + assert request is not None + # Reclaim the request + result = await rq.reclaim_request(request) + assert result is not None + assert result.was_already_handled is False -async def test_from_storage_object() -> None: - storage_client = service_locator.get_storage_client() + # Verify we can fetch it again + reclaimed_request = await rq.fetch_next_request() + assert reclaimed_request is not None + assert reclaimed_request.id == request.id + assert reclaimed_request.url == 'https://example.com' - storage_object = StorageMetadata( - id='dummy-id', - name='dummy-name', - accessed_at=datetime.now(timezone.utc), - created_at=datetime.now(timezone.utc), - modified_at=datetime.now(timezone.utc), - extra_attribute='extra', - ) - request_queue = RequestQueue.from_storage_object(storage_client, storage_object) - - assert request_queue.id == storage_object.id - assert request_queue.name == storage_object.name - assert request_queue.storage_object == storage_object - assert storage_object.model_extra.get('extra_attribute') == 'extra' # type: ignore[union-attr] - - -async def test_add_batched_requests_with_retry(request_queue: RequestQueue) -> None: - """Test that unprocessed requests are retried. - - Unprocessed requests should not count in `get_total_count` - Test creates situation where in `batch_add_requests` call in first batch 3 requests are unprocessed. - On each following `batch_add_requests` call the last request in batch remains unprocessed. - In this test `batch_add_requests` is called once with batch of 10 requests. With retries only 1 request should - remain unprocessed.""" - - batch_add_requests_call_counter = count(start=1) - service_locator.get_storage_client() - initial_request_count = 10 - expected_added_requests = 9 - requests = [f'https://example.com/{i}' for i in range(initial_request_count)] - - class MockedRequestQueueClient(RequestQueueClient): - """Patched memory storage client that simulates unprocessed requests.""" - - async def _batch_add_requests_without_last_n( - self, batch: Sequence[Request], n: int = 0 - ) -> BatchRequestsOperationResponse: - response = await super().batch_add_requests(batch[:-n]) - response.unprocessed_requests = [ - UnprocessedRequest(url=r.url, unique_key=r.unique_key, method=r.method) for r in batch[-n:] - ] - return response - - async def batch_add_requests( - self, - requests: Sequence[Request], - *, - forefront: bool = False, # noqa: ARG002 - ) -> BatchRequestsOperationResponse: - """Mocked client behavior that simulates unprocessed requests. - - It processes all except last three at first run, then all except last none. - Overall if tried with the same batch it will process all except the last one. 
- """ - call_count = next(batch_add_requests_call_counter) - if call_count == 1: - # Process all but last three - return await self._batch_add_requests_without_last_n(requests, n=3) - # Process all but last - return await self._batch_add_requests_without_last_n(requests, n=1) - - mocked_storage_client = AsyncMock(spec=StorageClient) - mocked_storage_client.request_queue = MagicMock( - return_value=MockedRequestQueueClient(id='default', memory_storage_client=MemoryStorageClient.from_config()) +async def test_reclaim_request_with_forefront(rq: RequestQueue) -> None: + """Test reclaiming a request to the front of the queue.""" + # Add requests + await rq.add_request('https://example.com/first') + await rq.add_request('https://example.com/second') + + # Fetch the first request + first_request = await rq.fetch_next_request() + assert first_request is not None + assert first_request.url == 'https://example.com/first' + + # Reclaim it to the forefront + await rq.reclaim_request(first_request, forefront=True) + + # The reclaimed request should be returned first (before the second request) + next_request = await rq.fetch_next_request() + assert next_request is not None + assert next_request.url == 'https://example.com/first' + + +async def test_is_empty(rq: RequestQueue) -> None: + """Test checking if a request queue is empty.""" + # Initially the queue should be empty + assert await rq.is_empty() is True + + # Add a request + await rq.add_request('https://example.com') + assert await rq.is_empty() is False + + # Fetch and handle the request + request = await rq.fetch_next_request() + + assert request is not None + await rq.mark_request_as_handled(request) + + # Queue should be empty again + assert await rq.is_empty() is True + + +async def test_is_finished(rq: RequestQueue) -> None: + """Test checking if a request queue is finished.""" + # Initially the queue should be finished (empty and no background tasks) + assert await rq.is_finished() is True + + # Add a request + await rq.add_request('https://example.com') + assert await rq.is_finished() is False + + # Add requests in the background + await rq.add_requests( + ['https://example.com/1', 'https://example.com/2'], + wait_for_all_requests_to_be_added=False, ) - request_queue = RequestQueue(id='default', name='some_name', storage_client=mocked_storage_client) + # Queue shouldn't be finished while background tasks are running + assert await rq.is_finished() is False - # Add the requests to the RQ in batches - await request_queue.add_requests_batched( - requests, wait_for_all_requests_to_be_added=True, wait_time_between_batches=timedelta(0) + # Wait for background tasks to finish + await asyncio.sleep(0.2) + + # Process all requests + while True: + request = await rq.fetch_next_request() + if request is None: + break + await rq.mark_request_as_handled(request) + + # Now queue should be finished + assert await rq.is_finished() is True + + +async def test_mark_non_existent_request_as_handled(rq: RequestQueue) -> None: + """Test marking a non-existent request as handled.""" + # Create a request that hasn't been added to the queue + request = Request.from_url(url='https://example.com', id='non-existent-id') + + # Attempt to mark it as handled + result = await rq.mark_request_as_handled(request) + assert result is None + + +async def test_reclaim_non_existent_request(rq: RequestQueue) -> None: + """Test reclaiming a non-existent request.""" + # Create a request that hasn't been added to the queue + request = Request.from_url(url='https://example.com', 
id='non-existent-id') + + # Attempt to reclaim it + result = await rq.reclaim_request(request) + assert result is None + + +async def test_drop( + storage_client: StorageClient, + configuration: Configuration, +) -> None: + """Test dropping a request queue removes it from cache and clears its data.""" + rq = await RequestQueue.open( + name='drop_test', + storage_client=storage_client, + configuration=configuration, ) - # Ensure the batch was processed correctly - assert await request_queue.get_total_count() == expected_added_requests - # Fetch and validate each request in the queue - for original_request in requests[:expected_added_requests]: - next_request = await request_queue.fetch_next_request() - assert next_request is not None + # Add a request + await rq.add_request('https://example.com') + + # Verify request queue exists in cache + assert rq.id in RequestQueue._cache_by_id + if rq.name: + assert rq.name in RequestQueue._cache_by_name - expected_url = original_request if isinstance(original_request, str) else original_request.url - assert next_request.url == expected_url + # Drop the request queue + await rq.drop() + + # Verify request queue was removed from cache + assert rq.id not in RequestQueue._cache_by_id + if rq.name: + assert rq.name not in RequestQueue._cache_by_name + + # Verify request queue is empty (by creating a new one with the same name) + new_rq = await RequestQueue.open( + name='drop_test', + storage_client=storage_client, + configuration=configuration, + ) - # Confirm the queue is empty after processing all requests - assert await request_queue.is_empty() is True + # Verify the queue is empty + assert await new_rq.is_empty() is True + assert new_rq.metadata.total_request_count == 0 + assert new_rq.metadata.pending_request_count == 0 + await new_rq.drop() diff --git a/uv.lock b/uv.lock index 77035eff2c..e426c0fd9b 100644 --- a/uv.lock +++ b/uv.lock @@ -746,7 +746,7 @@ dev = [ { name = "pytest-only", specifier = "~=2.1.0" }, { name = "pytest-xdist", specifier = "~=3.6.0" }, { name = "ruff", specifier = "~=0.11.0" }, - { name = "setuptools", specifier = "~=79.0.0" }, + { name = "setuptools" }, { name = "sortedcontainers-stubs", specifier = "~=2.4.0" }, { name = "types-beautifulsoup4", specifier = "~=4.12.0.20240229" }, { name = "types-cachetools", specifier = "~=5.5.0.20240820" }, diff --git a/website/generate_module_shortcuts.py b/website/generate_module_shortcuts.py index 5a18e8d3f3..61acc68ade 100755 --- a/website/generate_module_shortcuts.py +++ b/website/generate_module_shortcuts.py @@ -5,6 +5,7 @@ import importlib import inspect import json +from pathlib import Path from typing import TYPE_CHECKING if TYPE_CHECKING: @@ -55,5 +56,5 @@ def resolve_shortcuts(shortcuts: dict) -> None: resolve_shortcuts(shortcuts) -with open('module_shortcuts.json', 'w', encoding='utf-8') as shortcuts_file: +with Path('module_shortcuts.json').open('w', encoding='utf-8') as shortcuts_file: json.dump(shortcuts, shortcuts_file, indent=4, sort_keys=True)
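
For context, the workflow exercised by the rewritten request-queue tests can be sketched as a standalone script. This is an illustrative sketch only, based on the calls that appear in the diff above (RequestQueue.open with a name and storage client, the renamed add_requests, fetch_next_request, mark_request_as_handled, drop, and the new FileSystemStorageClient); the queue name and URLs are placeholders, and the snippet is not part of the change set itself.

import asyncio

from crawlee.storage_clients import FileSystemStorageClient
from crawlee.storages import RequestQueue


async def main() -> None:
    # Open (or create) a named queue; passing a storage client mirrors the new
    # test fixtures and is optional when the default client is acceptable.
    rq = await RequestQueue.open(
        name='example-queue',
        storage_client=FileSystemStorageClient(),
    )

    # The former add_requests_batched() call is now add_requests().
    await rq.add_requests(['https://crawlee.dev/', 'https://crawlee.dev/python/'])

    # Drain the queue, marking each request as handled.
    while (request := await rq.fetch_next_request()) is not None:
        await rq.mark_request_as_handled(request)

    # Remove the queue and its data when done.
    await rq.drop()


if __name__ == '__main__':
    asyncio.run(main())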