Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ dependencies = [
"pydantic>=2.11.0",
"pyee>=9.0.0",
"tldextract>=5.1.0",
"typing-extensions>=4.1.0",
"typing-extensions>=4.10.0",
"yarl>=1.18.0",
]

Expand Down
9 changes: 5 additions & 4 deletions src/crawlee/_request.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from collections.abc import Iterator, MutableMapping
from collections.abc import Iterator, Mapping, MutableMapping
from datetime import datetime
from enum import IntEnum
from typing import TYPE_CHECKING, Annotated, Any, TypedDict, cast
Expand Down Expand Up @@ -135,7 +135,7 @@ class RequestOptions(TypedDict):
keep_url_fragment: NotRequired[bool]
use_extended_unique_key: NotRequired[bool]
always_enqueue: NotRequired[bool]
user_data: NotRequired[dict[str, JsonSerializable]]
user_data: NotRequired[Mapping[str, JsonSerializable]]
no_retry: NotRequired[bool]
enqueue_strategy: NotRequired[EnqueueStrategy]
max_retries: NotRequired[int | None]
Expand Down Expand Up @@ -200,7 +200,7 @@ class Request(BaseModel):
headers: HttpHeaders = HttpHeaders()
"""HTTP request headers."""

user_data: dict[str, JsonSerializable] = {}
user_data: MutableMapping[str, JsonSerializable] = {}
"""Custom user data assigned to the request. Use this to save any request related data to the
request's scope, keeping them accessible on retries, failures etc.
"""
Expand All @@ -209,8 +209,9 @@ class Request(BaseModel):
headers: Annotated[HttpHeaders, Field(default_factory=HttpHeaders)]
"""HTTP request headers."""

# Internally, the model contains `UserData`, this is just for convenience
user_data: Annotated[
dict[str, JsonSerializable], # Internally, the model contains `UserData`, this is just for convenience
MutableMapping[str, JsonSerializable],
Field(alias='userData', default_factory=UserData),
PlainValidator(user_data_adapter.validate_python),
PlainSerializer(
Expand Down
22 changes: 10 additions & 12 deletions src/crawlee/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import json
import logging
import re
from collections.abc import Awaitable, Coroutine, Sequence
from collections.abc import Awaitable, Coroutine, MutableMapping, Sequence

from typing_extensions import NotRequired, Required, Self, Unpack

Expand All @@ -27,9 +27,7 @@
from crawlee.storage_clients import StorageClient
from crawlee.storages import KeyValueStore

# Workaround for https://github.com/pydantic/pydantic/issues/9445
J = TypeVar('J', bound='JsonSerializable')
JsonSerializable = list[J] | dict[str, J] | str | bool | int | float | None
JsonSerializable = dict[str, 'JsonSerializable'] | list['JsonSerializable'] | str | int | float | bool | None
else:
from pydantic import JsonValue as JsonSerializable
Comment thread
vdusek marked this conversation as resolved.

Expand Down Expand Up @@ -198,7 +196,7 @@ class PushDataKwargs(TypedDict):


class PushDataFunctionCall(PushDataKwargs):
data: list[dict[str, Any]] | dict[str, Any]
data: Sequence[Mapping[str, JsonSerializable]] | Mapping[str, JsonSerializable]
dataset_id: str | None
dataset_name: str | None
dataset_alias: str | None
Expand Down Expand Up @@ -300,7 +298,7 @@ async def add_requests(

async def push_data(
self,
data: list[dict[str, Any]] | dict[str, Any],
data: Sequence[Mapping[str, JsonSerializable]] | Mapping[str, JsonSerializable],
dataset_id: str | None = None,
dataset_name: str | None = None,
dataset_alias: str | None = None,
Expand Down Expand Up @@ -392,7 +390,7 @@ def __call__(
selector: str | None = None,
attribute: str | None = None,
label: str | None = None,
user_data: dict[str, Any] | None = None,
user_data: Mapping[str, JsonSerializable] | None = None,
transform_request_function: Callable[[RequestOptions], RequestOptions | RequestTransformAction] | None = None,
rq_id: str | None = None,
rq_name: str | None = None,
Expand All @@ -417,7 +415,7 @@ def __call__(
selector: str | None = None,
attribute: str | None = None,
label: str | None = None,
user_data: dict[str, Any] | None = None,
user_data: Mapping[str, JsonSerializable] | None = None,
transform_request_function: Callable[[RequestOptions], RequestOptions | RequestTransformAction] | None = None,
requests: Sequence[str | Request] | None = None,
rq_id: str | None = None,
Expand Down Expand Up @@ -465,7 +463,7 @@ def __call__(
selector: str = 'a',
attribute: str = 'href',
label: str | None = None,
user_data: dict[str, Any] | None = None,
user_data: Mapping[str, JsonSerializable] | None = None,
transform_request_function: Callable[[RequestOptions], RequestOptions | RequestTransformAction] | None = None,
**kwargs: Unpack[EnqueueLinksKwargs],
) -> Coroutine[None, None, list[Request]]:
Expand Down Expand Up @@ -543,7 +541,7 @@ class PushDataFunction(Protocol):

def __call__(
self,
data: list[dict[str, Any]] | dict[str, Any],
data: Sequence[Mapping[str, JsonSerializable]] | Mapping[str, JsonSerializable],
dataset_id: str | None = None,
dataset_name: str | None = None,
dataset_alias: str | None = None,
Expand Down Expand Up @@ -616,8 +614,8 @@ class UseStateFunction(Protocol):

def __call__(
self,
default_value: dict[str, JsonSerializable] | None = None,
) -> Coroutine[None, None, dict[str, JsonSerializable]]:
default_value: MutableMapping[str, JsonSerializable] | None = None,
) -> Coroutine[None, None, MutableMapping[str, JsonSerializable]]:
"""Call dunder method.

Args:
Expand Down
8 changes: 4 additions & 4 deletions src/crawlee/_utils/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@
from typing import TYPE_CHECKING, overload

if TYPE_CHECKING:
from collections.abc import AsyncIterator
from collections.abc import AsyncIterator, Mapping
from typing import Any, TextIO

from typing_extensions import Unpack

from crawlee._types import ExportDataCsvKwargs, ExportDataJsonKwargs
from crawlee._types import ExportDataCsvKwargs, ExportDataJsonKwargs, JsonSerializable

if sys.platform == 'win32':

Expand Down Expand Up @@ -150,7 +150,7 @@ async def atomic_write(


async def export_json_to_stream(
iterator: AsyncIterator[dict[str, Any]],
iterator: AsyncIterator[Mapping[str, JsonSerializable]],
Comment thread
vdusek marked this conversation as resolved.
dst: TextIO,
**kwargs: Unpack[ExportDataJsonKwargs],
) -> None:
Expand All @@ -159,7 +159,7 @@ async def export_json_to_stream(


async def export_csv_to_stream(
iterator: AsyncIterator[dict[str, Any]],
iterator: AsyncIterator[Mapping[str, JsonSerializable]],
dst: TextIO,
**kwargs: Unpack[ExportDataCsvKwargs],
) -> None:
Expand Down
8 changes: 4 additions & 4 deletions src/crawlee/crawlers/_abstract_http/_abstract_http_crawler.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import logging
from abc import ABC
from datetime import timedelta
from typing import TYPE_CHECKING, Any, Generic
from typing import TYPE_CHECKING, Generic

from more_itertools import partition
from pydantic import ValidationError
Expand All @@ -21,12 +21,12 @@
from ._http_crawling_context import HttpCrawlingContext, ParsedHttpCrawlingContext, TParseResult, TSelectResult

if TYPE_CHECKING:
from collections.abc import AsyncGenerator, Awaitable, Callable, Iterator
from collections.abc import AsyncGenerator, Awaitable, Callable, Iterator, Mapping

from typing_extensions import Unpack

from crawlee import RequestTransformAction
from crawlee._types import BasicCrawlingContext, EnqueueLinksKwargs, ExtractLinksFunction
from crawlee._types import BasicCrawlingContext, EnqueueLinksKwargs, ExtractLinksFunction, JsonSerializable

from ._abstract_http_parser import AbstractHttpParser

Expand Down Expand Up @@ -200,7 +200,7 @@ async def extract_links(
selector: str = 'a',
attribute: str = 'href',
label: str | None = None,
user_data: dict[str, Any] | None = None,
user_data: Mapping[str, JsonSerializable] | None = None,
transform_request_function: Callable[[RequestOptions], RequestOptions | RequestTransformAction]
| None = None,
**kwargs: Unpack[EnqueueLinksKwargs],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
from ._result_comparator import create_default_comparator

if TYPE_CHECKING:
from collections.abc import MutableMapping
from types import TracebackType

from typing_extensions import Unpack
Expand Down Expand Up @@ -286,7 +287,7 @@ async def _crawl_one(
self,
rendering_type: RenderingType,
context: BasicCrawlingContext,
state: dict[str, JsonSerializable] | None = None,
state: MutableMapping[str, JsonSerializable] | None = None,
) -> SubCrawlerRun:
"""Perform a one request crawl with specific context pipeline and return `SubCrawlerRun`.

Expand All @@ -297,8 +298,8 @@ async def _crawl_one(
if state is not None:

async def get_input_state(
default_value: dict[str, JsonSerializable] | None = None, # noqa:ARG001 # Intentionally unused arguments. Closure, that generates same output regardless of inputs.
) -> dict[str, JsonSerializable]:
default_value: MutableMapping[str, JsonSerializable] | None = None, # noqa:ARG001 # Intentionally unused arguments. Closure, that generates same output regardless of inputs.
) -> MutableMapping[str, JsonSerializable]:
return state

use_state_function = get_input_state
Expand Down Expand Up @@ -411,8 +412,10 @@ async def _run_request_handler(self, context: BasicCrawlingContext) -> None:
# avoid static crawl to modify the state.
# (This static crawl is performed only to evaluate rendering type detection.)
kvs = await context.get_key_value_store()
default_value = dict[str, JsonSerializable]()
old_state: dict[str, JsonSerializable] = await kvs.get_value(self._CRAWLEE_STATE_KEY, default_value)
default_value: MutableMapping[str, JsonSerializable] = {}
old_state: MutableMapping[str, JsonSerializable] = await kvs.get_value(
self._CRAWLEE_STATE_KEY, default_value
)
old_state_copy = deepcopy(old_state)

pw_run = await self._crawl_one('client only', context=context)
Expand Down
10 changes: 5 additions & 5 deletions src/crawlee/crawlers/_basic/_basic_crawler.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@

if TYPE_CHECKING:
import re
from collections.abc import Iterator
from collections.abc import Iterator, Mapping, MutableMapping
from contextlib import AbstractAsyncContextManager

from crawlee._types import (
Expand Down Expand Up @@ -856,8 +856,8 @@ async def add_requests(

async def use_state(
self,
default_value: dict[str, JsonSerializable] | None = None,
) -> dict[str, JsonSerializable]:
default_value: MutableMapping[str, JsonSerializable] | None = None,
) -> MutableMapping[str, JsonSerializable]:
kvs = await self.get_key_value_store()
return await kvs.get_auto_saved_value(f'{self._CRAWLEE_STATE_KEY}_{self._id}', default_value)

Expand Down Expand Up @@ -941,7 +941,7 @@ async def export_data(

async def _push_data(
self,
data: list[dict[str, Any]] | dict[str, Any],
data: Sequence[Mapping[str, JsonSerializable]] | Mapping[str, JsonSerializable],
Comment thread
vdusek marked this conversation as resolved.
dataset_id: str | None = None,
dataset_name: str | None = None,
dataset_alias: str | None = None,
Expand Down Expand Up @@ -1015,7 +1015,7 @@ async def enqueue_links(
selector: str | None = None,
attribute: str | None = None,
label: str | None = None,
user_data: dict[str, Any] | None = None,
user_data: Mapping[str, JsonSerializable] | None = None,
transform_request_function: Callable[[RequestOptions], RequestOptions | RequestTransformAction]
| None = None,
requests: Sequence[str | Request] | None = None,
Expand Down
3 changes: 2 additions & 1 deletion src/crawlee/crawlers/_playwright/_playwright_crawler.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
HttpHeaders,
HttpMethod,
HttpPayload,
JsonSerializable,
)
from crawlee.browsers._types import BrowserType

Expand Down Expand Up @@ -384,7 +385,7 @@ async def extract_links(
selector: str = 'a',
attribute: str = 'href',
label: str | None = None,
user_data: dict | None = None,
user_data: Mapping[str, JsonSerializable] | None = None,
transform_request_function: Callable[[RequestOptions], RequestOptions | RequestTransformAction]
| None = None,
**kwargs: Unpack[EnqueueLinksKwargs],
Expand Down
5 changes: 4 additions & 1 deletion src/crawlee/sessions/_models.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

from collections.abc import MutableMapping
from datetime import datetime, timedelta
from typing import Annotated, Any

Expand All @@ -13,6 +14,8 @@
computed_field,
)

from crawlee._types import JsonSerializable

from ._cookies import CookieParam
from ._session import Session

Expand All @@ -24,7 +27,7 @@ class SessionModel(BaseModel):

id: Annotated[str, Field(alias='id')]
max_age: Annotated[timedelta, Field(alias='maxAge')]
user_data: Annotated[dict, Field(alias='userData')]
user_data: Annotated[MutableMapping[str, JsonSerializable], Field(alias='userData')]
max_error_score: Annotated[float, Field(alias='maxErrorScore')]
error_score_decrement: Annotated[float, Field(alias='errorScoreDecrement')]
created_at: Annotated[datetime, Field(alias='createdAt')]
Expand Down
8 changes: 5 additions & 3 deletions src/crawlee/sessions/_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@
from crawlee.sessions._cookies import CookieParam, SessionCookies

if TYPE_CHECKING:
from collections.abc import Mapping, MutableMapping
from http.cookiejar import CookieJar

from crawlee._types import JsonSerializable
from crawlee.sessions._models import SessionModel

logger = getLogger(__name__)
Expand All @@ -36,7 +38,7 @@ def __init__(
*,
id: str | None = None,
max_age: timedelta = timedelta(minutes=50),
user_data: dict | None = None,
user_data: Mapping[str, JsonSerializable] | None = None,
max_error_score: float = 3.0,
error_score_decrement: float = 0.5,
created_at: datetime | None = None,
Expand All @@ -63,7 +65,7 @@ def __init__(
"""
self._id = id or crypto_random_object_id(length=10)
self._max_age = max_age
self._user_data = user_data or {}
self._user_data: MutableMapping[str, JsonSerializable] = dict(user_data) if user_data is not None else {}
self._max_error_score = max_error_score
self._error_score_decrement = error_score_decrement
self._created_at = created_at or datetime.now(timezone.utc)
Expand Down Expand Up @@ -117,7 +119,7 @@ def id(self) -> str:
return self._id

@property
def user_data(self) -> dict:
def user_data(self) -> MutableMapping[str, JsonSerializable]:
"""Get the user data."""
return self._user_data

Expand Down
Loading
Loading