diff --git a/docs/guides/code_examples/request_router/router_middleware.py b/docs/guides/code_examples/request_router/router_middleware.py
new file mode 100644
index 0000000000..3f9e04304c
--- /dev/null
+++ b/docs/guides/code_examples/request_router/router_middleware.py
@@ -0,0 +1,50 @@
+import asyncio
+import time
+
+from crawlee import Request
+from crawlee.crawlers import ParselCrawler, ParselCrawlingContext
+from crawlee.router import Router
+
+
+async def main() -> None:
+ # Create a custom router instance
+ router = Router[ParselCrawlingContext]()
+
+ # Register a middleware that logs every request before it reaches a handler
+ @router.use
+ async def logging_middleware(context: ParselCrawlingContext) -> None:
+ context.log.info(
+ f'Processing request: {context.request.url} label={context.request.label}'
+ )
+
+ # Register a middleware that adds a timestamp to the request's user data
+ @router.use
+ async def timestamp_middleware(context: ParselCrawlingContext) -> None:
+ context.request.user_data['start_time'] = time.monotonic()
+
+ @router.default_handler
+ async def default_handler(context: ParselCrawlingContext) -> None:
+ context.log.info(f'Processing {context.request.url} with default handler')
+
+ @router.handler('CATEGORY')
+ async def category_handler(context: ParselCrawlingContext) -> None:
+ context.log.info(f'Processing {context.request.url} with category handler')
+
+ crawler = ParselCrawler(
+ request_handler=router,
+ max_requests_per_crawl=10,
+ )
+
+ await crawler.run(
+ [
+ 'https://warehouse-theme-metal.myshopify.com/',
+ Request.from_url(
+ 'https://warehouse-theme-metal.myshopify.com/collections/all',
+ label='CATEGORY',
+ ),
+ ]
+ )
+
+
+if __name__ == '__main__':
+ asyncio.run(main())
diff --git a/docs/guides/request_router.mdx b/docs/guides/request_router.mdx
index d9d7733abf..1aceab2637 100644
--- a/docs/guides/request_router.mdx
+++ b/docs/guides/request_router.mdx
@@ -1,7 +1,7 @@
---
id: request-router
title: Request router
-description: Learn how to use the Router class to organize request handlers, error handlers, and pre-navigation hooks in Crawlee.
+description: Learn how to use the Router class to organize request handlers, middleware, error handlers, and pre-navigation hooks in Crawlee.
---
import ApiLink from '@site/src/components/ApiLink';
@@ -16,6 +16,7 @@ import ErrorHandler from '!!raw-loader!roa-loader!./code_examples/request_router
import FailedRequestHandler from '!!raw-loader!roa-loader!./code_examples/request_router/failed_request_handler.py';
import PlaywrightPreNavigation from '!!raw-loader!roa-loader!./code_examples/request_router/playwright_pre_navigation.py';
import AdaptiveCrawlerHandlers from '!!raw-loader!roa-loader!./code_examples/request_router/adaptive_crawler_handlers.py';
+import RouterMiddleware from '!!raw-loader!roa-loader!./code_examples/request_router/router_middleware.py';
The `Router` class manages request flow and coordinates the execution of user-defined logic in Crawlee projects. It routes incoming requests to appropriate user-defined handlers based on labels, manages error scenarios, and provides hooks for pre-navigation execution. The `Router` serves as the orchestrator for all crawling operations, ensuring that each request is processed by the correct handler according to its type and label.
@@ -57,6 +58,14 @@ More complex crawling projects often require different processing logic for vari
{BasicRequestHandlers}
+## Middleware
+
+Middlewares are functions registered with `router.use()` that execute before the matched request handler on every request, regardless of the request label. Multiple middlewares can be registered and are executed sequentially in the order they were registered. If a middleware raises an exception, the execution chain is interrupted and the handler is not called.
+
+
+ {RouterMiddleware}
+
+
## Error handlers
Crawlee provides error handling mechanisms to manage request processing failures. It distinguishes between recoverable errors that may succeed on retry and permanent failures that require alternative handling strategies.
@@ -107,6 +116,6 @@ The `AdaptivePlaywrightCrawler``Router` class and how to organize your crawling logic. You learned how to use built-in and custom routers, implement request handlers with label-based routing, handle errors with error and failed request handlers, and configure pre-navigation hooks for different crawler types.
+This guide introduced you to the `Router` class and how to organize your crawling logic. You learned how to use built-in and custom routers, implement request handlers with label-based routing, add middleware with `router.use()`, handle errors with error and failed request handlers, and configure pre-navigation hooks for different crawler types.
If you have questions or need assistance, feel free to reach out on our [GitHub](https://github.com/apify/crawlee-python) or join our [Discord community](https://discord.com/invite/jyEM2PRvMU). Happy scraping!
diff --git a/src/crawlee/router.py b/src/crawlee/router.py
index 6d72aa9bf7..e1e0cb613d 100644
--- a/src/crawlee/router.py
+++ b/src/crawlee/router.py
@@ -32,6 +32,10 @@ class Router(Generic[TCrawlingContext]):
router = Router[HttpCrawlingContext]()
+ # Middleware executed for every request before the handlers
+ @router.use
+ async def logging_middleware(context: HttpCrawlingContext) -> None:
+ context.log.info(f'Processing request: {context.request.url} label={context.request.label}')
# Handler for requests without a matching label handler
@router.default_handler
@@ -59,6 +63,7 @@ async def main() -> None:
def __init__(self) -> None:
self._default_handler: RequestHandler[TCrawlingContext] | None = None
self._handlers_by_label = dict[str, RequestHandler[TCrawlingContext]]()
+ self._middlewares = list[RequestHandler[TCrawlingContext]]()
def default_handler(self: Router, handler: RequestHandler[TCrawlingContext]) -> RequestHandler[TCrawlingContext]:
"""Register a default request handler.
@@ -91,8 +96,19 @@ def wrapper(handler: Callable[[TCrawlingContext], Awaitable]) -> Callable[[TCraw
return wrapper
+ def use(self, middleware: RequestHandler[TCrawlingContext]) -> RequestHandler[TCrawlingContext]:
+ """Register a middleware.
+
+ A middleware is a function that is executed before the request handler.
+ """
+ self._middlewares.append(middleware)
+ return middleware
+
async def __call__(self, context: TCrawlingContext) -> None:
"""Invoke a request handler that matches the request label (or the default)."""
+ for middleware in self._middlewares:
+ await middleware(context)
+
context.request.state = RequestState.REQUEST_HANDLER
if context.request.label is None or context.request.label not in self._handlers_by_label:
if self._default_handler is None:
diff --git a/tests/unit/test_router.py b/tests/unit/test_router.py
index c87a44d323..1b36f9a013 100644
--- a/tests/unit/test_router.py
+++ b/tests/unit/test_router.py
@@ -116,3 +116,114 @@ async def handler(_context: MockContext) -> None:
await router(MockContext(label='B'))
mock_handler.assert_called_with('B')
assert mock_handler.call_count == 2
+
+
+async def test_router_use_middleware() -> None:
+ router = Router[MockContext]()
+ mock_middleware = Mock()
+ mock_default_handler = Mock()
+
+ @router.use
+ async def middleware_1(_context: MockContext) -> None:
+ mock_middleware(call='middleware_1')
+
+ @router.use
+ async def middleware_2(_context: MockContext) -> None:
+ mock_middleware(call='middleware_2')
+
+ @router.default_handler
+ async def default_handler(_context: MockContext) -> None:
+ mock_default_handler()
+
+ await router(MockContext(label=None))
+
+ assert mock_middleware.call_count == 2
+ mock_middleware.assert_any_call(call='middleware_1')
+ mock_middleware.assert_any_call(call='middleware_2')
+ mock_default_handler.assert_called_once()
+ # Check order of middleware execution
+ assert mock_middleware.call_args_list[0][1] == {'call': 'middleware_1'}
+ assert mock_middleware.call_args_list[1][1] == {'call': 'middleware_2'}
+
+
+async def test_router_use_middleware_with_label() -> None:
+ router = Router[MockContext]()
+ mock_middleware = Mock()
+ mock_handler = Mock()
+
+ @router.use
+ async def middleware_1(_context: MockContext) -> None:
+ mock_middleware(call='middleware_1')
+
+ @router.use
+ async def middleware_2(_context: MockContext) -> None:
+ mock_middleware(call='middleware_2')
+
+ @router.handler('A')
+ async def handler_a(_context: MockContext) -> None:
+ mock_handler(call='handler_a')
+
+ @router.default_handler
+ async def default_handler(_context: MockContext) -> None:
+ mock_handler(call='default_handler')
+
+ await router(MockContext(label='A'))
+ await router(MockContext(label=None))
+
+ assert mock_middleware.call_count == 4
+ assert mock_handler.call_count == 2
+
+ assert mock_middleware.call_args_list[0][1] == {'call': 'middleware_1'}
+ assert mock_middleware.call_args_list[1][1] == {'call': 'middleware_2'}
+ assert mock_handler.call_args_list[0][1] == {'call': 'handler_a'}
+ assert mock_middleware.call_args_list[2][1] == {'call': 'middleware_1'}
+ assert mock_middleware.call_args_list[3][1] == {'call': 'middleware_2'}
+ assert mock_handler.call_args_list[1][1] == {'call': 'default_handler'}
+
+
+async def test_router_middleware_order_execution() -> None:
+ router = Router[MockContext]()
+ mock_execution_order = Mock()
+
+ @router.use
+ async def middleware_1(_context: MockContext) -> None:
+ mock_execution_order(call='middleware_1')
+
+ @router.use
+ async def middleware_2(_context: MockContext) -> None:
+ mock_execution_order(call='middleware_2')
+
+ @router.default_handler
+ async def default_handler(_context: MockContext) -> None:
+ mock_execution_order(call='default_handler')
+
+ await router(MockContext(label=None))
+
+ assert mock_execution_order.call_count == 3
+ assert mock_execution_order.call_args_list[0][1] == {'call': 'middleware_1'}
+ assert mock_execution_order.call_args_list[1][1] == {'call': 'middleware_2'}
+ assert mock_execution_order.call_args_list[2][1] == {'call': 'default_handler'}
+
+
+async def test_router_middleware_exception_interrupts_chain() -> None:
+ router = Router[MockContext]()
+ mock_execution_order = Mock()
+
+ @router.use
+ async def middleware_1(_context: MockContext) -> None:
+ mock_execution_order(call='middleware_1')
+ raise ValueError('middleware error')
+
+ @router.use
+ async def middleware_2(_context: MockContext) -> None:
+ mock_execution_order(call='middleware_2')
+
+ @router.default_handler
+ async def default_handler(_context: MockContext) -> None:
+ mock_execution_order(call='default_handler')
+
+ with pytest.raises(ValueError, match='middleware error'):
+ await router(MockContext(label=None))
+
+ assert mock_execution_order.call_count == 1
+ assert mock_execution_order.call_args_list[0][1] == {'call': 'middleware_1'}