diff --git a/docs/release-notes/whats-new-3.rst b/docs/release-notes/whats-new-3.rst index ca0ad2c782..d413410fb4 100644 --- a/docs/release-notes/whats-new-3.rst +++ b/docs/release-notes/whats-new-3.rst @@ -204,3 +204,24 @@ Change the media type of :attr:`~enums.MediaType.MESSAGEPACK` and newly introduced official ``application/vnd.msgpack``. https://www.iana.org/assignments/media-types/application/vnd.msgpack + + +Deprecated ``resolve_`` methods on route handlers +------------------------------------------------- + +All ``resolve_`` methods on the route handlers +(e.g. :meth:`~litestar.handlers.HTTPRouteHandler.resolve_response_headers`) have been +deprecated and will be removed in ``4.0``. The attributes can now safely be accessed +directly (e.g. `HTTPRouteHandlers.response_headers`). + + +Moved routing related methods from ``Router`` to ``Litestar`` +------------------------------------------------------------- + +:class:`~litestar.router.Router` now only holds route handlers and configuration, while +the actual routing is done in :class:`~litestar.app.Litestar`. With this, several +methods and properties have been moved from ``Router`` to ``Litestar``: + +- ``route_handler_method_map`` +- ``get_route_handler_map`` +- ``routes`` diff --git a/litestar/_asgi/asgi_router.py b/litestar/_asgi/asgi_router.py index 7add9c0635..5cec9be3f1 100644 --- a/litestar/_asgi/asgi_router.py +++ b/litestar/_asgi/asgi_router.py @@ -52,6 +52,7 @@ class ASGIRouter: "_plain_routes", "_registered_routes", "_static_routes", + "_trie_initialized", "app", "root_route_map_node", "route_handler_index", @@ -69,6 +70,7 @@ def __init__(self, app: Litestar) -> None: self._mount_routes: dict[str, RouteTrieNode] = {} self._plain_routes: set[str] = set() self._registered_routes: set[HTTPRoute | WebSocketRoute | ASGIRoute] = set() + self._trie_initialized = False self.app = app self.root_route_map_node: RouteTrieNode = create_node() self.route_handler_index: dict[str, RouteHandlerType] = {} @@ -94,7 +96,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: ScopeState.from_scope(scope).exception_handlers = self._app_exception_handlers raise else: - ScopeState.from_scope(scope).exception_handlers = route_handler.resolve_exception_handlers() + ScopeState.from_scope(scope).exception_handlers = route_handler.exception_handlers scope["route_handler"] = route_handler scope["path_template"] = path_template await asgi_app(scope, receive, send) @@ -145,8 +147,16 @@ def construct_routing_trie(self) -> None: This map is used in the asgi router to route requests. """ - new_routes = [route for route in self.app.routes if route not in self._registered_routes] - for route in new_routes: + if self._trie_initialized: # pragma: no cover + self._mount_paths_regex = None + self._mount_routes = {} + self._plain_routes = set() + self._registered_routes = set() + self.root_route_map_node = create_node() + self.route_handler_index = {} + self.route_mapping = defaultdict(list) + + for route in self.app.routes: add_route_to_trie( app=self.app, mount_routes=self._mount_routes, diff --git a/litestar/_asgi/routing_trie/mapping.py b/litestar/_asgi/routing_trie/mapping.py index 319a821041..74140e720c 100644 --- a/litestar/_asgi/routing_trie/mapping.py +++ b/litestar/_asgi/routing_trie/mapping.py @@ -189,7 +189,7 @@ def build_route_middleware_stack( from litestar.routes import HTTPRoute asgi_handler: ASGIApp = route.handle # type: ignore[assignment] - handler_middleware = route_handler.resolve_middleware() + handler_middleware = route_handler.middleware has_cached_route = isinstance(route, HTTPRoute) and any(r.cache for r in route.route_handlers) has_middleware = ( app.csrf_config or app.compression_config or has_cached_route or app.allowed_hosts or handler_middleware diff --git a/litestar/_kwargs/extractors.py b/litestar/_kwargs/extractors.py index e5a7737125..992c2b696f 100644 --- a/litestar/_kwargs/extractors.py +++ b/litestar/_kwargs/extractors.py @@ -343,7 +343,7 @@ async def _extract_multipart( stream=connection.stream(), boundary=connection.content_type[-1].get("boundary", "").encode(), multipart_form_part_limit=multipart_form_part_limit, - type_decoders=connection.route_handler.resolve_type_decoders(), + type_decoders=connection.route_handler.type_decoders, ) else: form_values = scope_state.form diff --git a/litestar/_layers/utils.py b/litestar/_layers/utils.py index 61afd61c1d..ab5a278eec 100644 --- a/litestar/_layers/utils.py +++ b/litestar/_layers/utils.py @@ -12,18 +12,18 @@ from litestar.types.composite_types import ResponseCookies, ResponseHeaders -def narrow_response_headers(headers: ResponseHeaders | None) -> Sequence[ResponseHeader] | None: - """Given :class:`.types.ResponseHeaders` as a :class:`typing.Mapping`, create a list of +def narrow_response_headers(headers: ResponseHeaders | None) -> Sequence[ResponseHeader]: + """Given :class:`.types.ResponseHeaders` as a :class:`typing.Mapping`, create a tuple of :class:`.datastructures.response_header.ResponseHeader` from it, otherwise return ``headers`` unchanged """ return ( tuple(ResponseHeader(name=name, value=value) for name, value in headers.items()) if isinstance(headers, Mapping) else headers - ) + ) or () -def narrow_response_cookies(cookies: ResponseCookies | None) -> Sequence[Cookie] | None: +def narrow_response_cookies(cookies: ResponseCookies | None) -> Sequence[Cookie]: """Given :class:`.types.ResponseCookies` as a :class:`typing.Mapping`, create a list of :class:`.datastructures.cookie.Cookie` from it, otherwise return ``cookies`` unchanged """ @@ -31,4 +31,4 @@ def narrow_response_cookies(cookies: ResponseCookies | None) -> Sequence[Cookie] tuple(Cookie(key=key, value=value) for key, value in cookies.items()) if isinstance(cookies, Mapping) else cookies - ) + ) or () diff --git a/litestar/_openapi/datastructures.py b/litestar/_openapi/datastructures.py index 91761e7eff..254b80b49e 100644 --- a/litestar/_openapi/datastructures.py +++ b/litestar/_openapi/datastructures.py @@ -243,6 +243,6 @@ def add_operation_id(self, operation_id: str) -> None: if operation_id in self.operation_ids: raise ImproperlyConfiguredException( "operation_ids must be unique, " - f"please ensure the value of 'operation_id' is either not set or unique for {operation_id}" + f"please ensure the value of 'operation_id' is either not set or unique for {operation_id!r}" ) self.operation_ids.add(operation_id) diff --git a/litestar/_openapi/parameters.py b/litestar/_openapi/parameters.py index 07b8f549ca..f296e0ef6f 100644 --- a/litestar/_openapi/parameters.py +++ b/litestar/_openapi/parameters.py @@ -94,8 +94,8 @@ def __init__( self.schema_creator = SchemaCreator.from_openapi_context(self.context, prefer_alias=True) self.route_handler = route_handler self.parameters = ParameterCollection(route_handler) - self.dependency_providers = route_handler.resolve_dependencies() - self.layered_parameters = route_handler.resolve_layered_parameters() + self.dependency_providers = route_handler.dependencies + self.layered_parameters = route_handler.parameter_field_definitions self.path_parameters = path_parameters def create_parameter(self, field_definition: FieldDefinition, parameter_name: str) -> Parameter: diff --git a/litestar/_openapi/path_item.py b/litestar/_openapi/path_item.py index 987c89335e..de48e906a3 100644 --- a/litestar/_openapi/path_item.py +++ b/litestar/_openapi/path_item.py @@ -36,7 +36,7 @@ def create_path_item(self) -> PathItem: A PathItem instance. """ for http_method, route_handler in self.route.route_handler_map.items(): - if not route_handler.resolve_include_in_schema(): + if not route_handler.include_in_schema: continue operation = self.create_operation_for_handler_method(route_handler, HttpMethod(http_method)) @@ -64,7 +64,7 @@ def create_operation_for_handler_method( request_body = None if data_field := signature_fields.get("data"): request_body = create_request_body( - self.context, route_handler.handler_id, route_handler.resolve_data_dto(), data_field + self.context, route_handler.handler_id, route_handler.data_dto, data_field ) raises_validation_error = bool(data_field or self._path_item.parameters or parameters) @@ -74,14 +74,14 @@ def create_operation_for_handler_method( return route_handler.operation_class( operation_id=operation_id, - tags=route_handler.resolve_tags() or None, + tags=sorted(route_handler.tags) if route_handler.tags else None, summary=route_handler.summary or SEPARATORS_CLEANUP_PATTERN.sub("", route_handler.handler_name.title()), description=self.create_description_for_handler(route_handler), deprecated=route_handler.deprecated, responses=responses, request_body=request_body, parameters=parameters or None, # type: ignore[arg-type] - security=route_handler.resolve_security() or None, + security=list(route_handler.security) if route_handler.security else None, ) def create_operation_id(self, route_handler: HTTPRouteHandler, http_method: HttpMethod) -> str: diff --git a/litestar/_openapi/plugin.py b/litestar/_openapi/plugin.py index 2727dc91c7..bee9d5b5df 100644 --- a/litestar/_openapi/plugin.py +++ b/litestar/_openapi/plugin.py @@ -198,7 +198,7 @@ def receive_route(self, route: BaseRoute) -> None: if not isinstance(route, HTTPRoute): return - if any(route_handler.resolve_include_in_schema() for route_handler in route.route_handler_map.values()): + if any(route_handler.include_in_schema for route_handler in route.route_handler_map.values()): # Force recompute the schema if a new route is added self._openapi = None self.included_routes[route.path] = route diff --git a/litestar/_openapi/responses.py b/litestar/_openapi/responses.py index 1be71838f2..77f0c51ce4 100644 --- a/litestar/_openapi/responses.py +++ b/litestar/_openapi/responses.py @@ -127,7 +127,7 @@ def create_success_response(self) -> OpenAPIResponse: else: media_type = self.route_handler.media_type - if dto := self.route_handler.resolve_return_dto(): + if dto := self.route_handler.return_dto: result = dto.create_openapi_schema( field_definition=self.field_definition, handler_id=self.route_handler.handler_id, @@ -209,7 +209,7 @@ def set_success_response_headers(self, response: OpenAPIResponse) -> None: else: schema_creator = SchemaCreator.from_openapi_context(self.context, generate_examples=False) - for response_header in self.route_handler.resolve_response_headers(): + for response_header in self.route_handler.response_headers: header = OpenAPIHeader() for attribute_name, attribute_value in ( (k, v) for k, v in asdict(response_header).items() if v is not None @@ -223,7 +223,7 @@ def set_success_response_headers(self, response: OpenAPIResponse) -> None: response.headers[response_header.name] = header - if cookies := self.route_handler.resolve_response_cookies(): + if cookies := self.route_handler.response_cookies: response.headers["Set-Cookie"] = OpenAPIHeader( schema=Schema( all_of=[create_cookie_schema(cookie=cookie) for cookie in sorted(cookies, key=attrgetter("key"))] diff --git a/litestar/app.py b/litestar/app.py index 23c893c490..f7ddfad4c9 100644 --- a/litestar/app.py +++ b/litestar/app.py @@ -1,9 +1,12 @@ from __future__ import annotations +import collections import inspect +import itertools import logging import os import warnings +from collections import defaultdict from contextlib import ( AbstractAsyncContextManager, AsyncExitStack, @@ -14,7 +17,7 @@ from functools import partial from itertools import chain from pathlib import Path -from typing import TYPE_CHECKING, Any, AsyncGenerator, Callable, Iterable, Mapping, Sequence, TypedDict, cast +from typing import TYPE_CHECKING, Any, AsyncGenerator, Callable, Generator, Iterable, Mapping, Sequence, TypedDict, cast from uuid import UUID from litestar._asgi import ASGIRouter @@ -28,10 +31,13 @@ from litestar.datastructures.state import State from litestar.events.emitter import BaseEventEmitterBackend, SimpleEventEmitter from litestar.exceptions import ( + ImproperlyConfiguredException, LitestarWarning, MissingDependencyException, NoRouteMatchFoundException, ) +from litestar.handlers import ASGIRouteHandler, BaseRouteHandler, HTTPRouteHandler, WebsocketRouteHandler +from litestar.handlers.http_handlers._options import create_options_handler from litestar.logging.config import LoggingConfig, get_logger_placeholder from litestar.middleware._internal.cors import CORSMiddleware from litestar.openapi.config import OpenAPIConfig @@ -45,12 +51,13 @@ ) from litestar.plugins.base import CLIPlugin from litestar.router import Router +from litestar.routes import ASGIRoute, HTTPRoute, WebSocketRoute from litestar.stores.registry import StoreRegistry from litestar.types import Empty, TypeDecodersSequence -from litestar.types.internal_types import PathParameterDefinition, TemplateConfigType +from litestar.types.internal_types import PathParameterDefinition, RouteHandlerMapItem, TemplateConfigType from litestar.utils import deprecated, ensure_async_callable, join_paths, unique from litestar.utils.dataclass import extract_dataclass_items -from litestar.utils.predicates import is_async_callable +from litestar.utils.predicates import is_async_callable, is_class_and_subclass from litestar.utils.warnings import warn_pdb_on_exception if TYPE_CHECKING: @@ -67,7 +74,6 @@ from litestar.openapi.spec import SecurityRequirement from litestar.openapi.spec.open_api import OpenAPI from litestar.response import Response - from litestar.routes import ASGIRoute, HTTPRoute, WebSocketRoute from litestar.stores.base import Store from litestar.types import ( AfterExceptionHookHandler, @@ -158,7 +164,9 @@ class Litestar(Router): "pdb_on_exception", "plugins", "response_cache_config", + "route_handler_method_map", "route_map", + "routes", "state", "stores", "template_engine", @@ -412,7 +420,6 @@ def __init__( self.get_logger: GetLogger = get_logger_placeholder self.logger: Logger | None = None - self.routes: list[HTTPRoute | ASGIRoute | WebSocketRoute] = [] self.after_exception = [ensure_async_callable(h) for h in config.after_exception] self.allowed_hosts = cast("AllowedHostsConfig | None", config.allowed_hosts) @@ -466,8 +473,7 @@ def __init__( response_cookies=config.response_cookies, response_headers=config.response_headers, return_dto=config.return_dto, - # route handlers are registered below - route_handlers=[], + route_handlers=config.route_handlers, security=config.security, signature_namespace=config.signature_namespace, signature_types=config.signature_types, @@ -480,8 +486,16 @@ def __init__( self.asgi_router = ASGIRouter(app=self) - for route_handler in config.route_handlers: - self.register(route_handler) + self.routes: list[HTTPRoute | ASGIRoute | WebSocketRoute] = self._build_routes( + self._reduce_handlers(self.route_handlers) + ) + self.route_handler_method_map = self._create_route_handler_method_map(self.routes) + + # we have merged everything registered initially, so let's ensure we don't keep + # unnecessary references to temporary objects around and let them be gc'ed + self.route_handlers = () + + self.asgi_router.construct_routing_trie() if self.logging_config: self.get_logger = self.logging_config.configure() @@ -661,28 +675,181 @@ def from_config(cls, config: AppConfig) -> Self: """ return cls(**dict(extract_dataclass_items(config))) - def register(self, value: ControllerRouterHandler) -> None: # type: ignore[override] - """Register a route handler on the app. - - This method can be used to dynamically add endpoints to an application. - - Args: - value: An instance of :class:`Router <.router.Router>`, a subclass of - :class:`Controller <.controller.Controller>` or any function decorated by the route handler decorators. + @staticmethod + def _create_route_handler_method_map( + routes: Sequence[HTTPRoute | ASGIRoute | WebSocketRoute], + ) -> dict[str, RouteHandlerMapItem]: + """Map route paths to :class:`~litestar.types.internal_types.RouteHandlerMapItem` Returns: - None + A dictionary mapping paths to route handlers """ - routes = super().register(value=value) - + route_map: defaultdict[str, RouteHandlerMapItem] = defaultdict(dict) for route in routes: - route_handlers = get_route_handlers(route) + if isinstance(route, HTTPRoute): + route_map[route.path] = route.route_handler_map # type: ignore[assignment] + else: + route_map[route.path]["websocket" if isinstance(route, WebSocketRoute) else "asgi"] = ( + route.route_handler + ) + + return route_map + + def _build_routes(self, route_handlers: Iterable[BaseRouteHandler]) -> list[HTTPRoute | ASGIRoute | WebSocketRoute]: + """Create routes for all the handlers""" + routes: list[HTTPRoute | ASGIRoute | WebSocketRoute] = [] + + # since http routes can have multiple handlers (the case when one path handles + # multiple methods - we do the last mile of the routing outside the trie and on + # the route itself), we first group them by path and then create one route for + # each path + http_path_groups: dict[str, list[HTTPRouteHandler]] = collections.defaultdict(list) + + for handler in route_handlers: + if isinstance(handler, HTTPRouteHandler): + for path in handler.paths: + http_path_groups[path].append(handler) + elif isinstance(handler, WebsocketRouteHandler): + for path in handler.paths: + routes.append(WebSocketRoute(path=path, route_handler=handler)) + elif isinstance(handler, ASGIRouteHandler): + for path in handler.paths: + routes.append(ASGIRoute(path=path, route_handler=handler)) + + for path, http_handlers in http_path_groups.items(): + routes.append( + HTTPRoute(path=path, route_handlers=_maybe_add_options_handler(path, http_handlers, root=self)) + ) + + for finalized_route in routes: + route_handlers = get_route_handlers(finalized_route) for route_handler in route_handlers: - route_handler.on_registration(self, route=route) + route_handler.on_registration(route=finalized_route, app=self) for plugin in self.plugins.receive_route: - plugin.receive_route(route) + plugin.receive_route(finalized_route) + + return routes + + def _iter_handlers( + self, handlers: Iterable[ControllerRouterHandler], bases: list[Router] + ) -> Generator[tuple[BaseRouteHandler, list[Router]], None, None]: + """Recursively iterate over 'handlers', returning tuples of all sub-handlers + (i.e. handlers included in a Router / Controller) and their preceding layers. + + handlers = [ + Router( + path="/one", + route_handlers=[ + Router(path="/two", route_handlers=[handler_one]), + handler_two, + ], + ) + ] + + would return: + + [ + (handler_one, [, ]), + (handler_two, []), + ] + + """ + for handler in handlers: + handler = self._validate_registration_value(handler) + if isinstance(handler, Router): + yield from self._iter_handlers(handler.route_handlers, bases=[handler, *bases]) + else: + yield handler, bases + + def _reduce_handlers(self, handlers: Iterable[ControllerRouterHandler]) -> Generator[BaseRouteHandler, None, None]: + """Reduce possibly nested 'handlers' by recursively iterating over them and their + sub-handlers (e.g. handlers inside a router), and merging all the options of all + the layers above into one new handler. This allows us to eliminate all the + intermediate layers and keep the configuration only on the handlers. + + Using path merging as an example: + + .. code-block:: python + + @get("/handler-one") + async def handler_one() -> None: + pass + + + @get("/handler-two") + async def handler_two() -> None: + pass + + + router = Router( + path="/router-one", + route_handlers=[ + handler_one, + Router(path="/router-two", route_handlers=[handler_two]), + ], + ) + + would be the equivalent of writing: + + .. code-block:: python + + @get("/router-one/handler-one") + async def handler_one() -> None: + pass + + @get("/router-one/router-two/handler-two") + async def handler_two() -> None: + pass + """ + for handler, bases in self._iter_handlers(handlers, bases=[self]): + yield handler.merge(*bases) + + def _validate_registration_value(self, value: ControllerRouterHandler) -> RouteHandlerType | Router: + """Ensure values passed to the register method are supported.""" + from litestar.controller import Controller + from litestar.handlers import ASGIRouteHandler, WebsocketListener + + if is_class_and_subclass(value, Controller): + return value().as_router() + + # this narrows down to an ABC, but we assume a non-abstract subclass of the ABC superclass + if is_class_and_subclass(value, WebsocketListener): + return value().to_handler() # pyright: ignore + + if isinstance(value, Router): + if value is self: + raise ImproperlyConfiguredException("Cannot register a router on itself") + + return value + + if isinstance(value, (ASGIRouteHandler, HTTPRouteHandler, WebsocketRouteHandler)): + return value + + raise ImproperlyConfiguredException( + "Unsupported value passed to `Router.register`. " + "If you passed in a function or method, " + "make sure to decorate it first with one of the routing decorators" + ) + + def register(self, value: ControllerRouterHandler) -> None: + warnings.warn( + "Registering routes after the application instance has been " + "created is discouraged, as it might lead to unexpected behaviour " + "and is a costly operation. To register routes dynamically, a " + "plugin should be used where routes can be added to the " + "application via 'AppConfig' 'route_handlers' property", + category=LitestarWarning, + stacklevel=2, + ) + self.routes = [] + self.routes = self._build_routes( + itertools.chain( + self._reduce_handlers([value]), (h for route in self.routes for h in get_route_handlers(route)) + ) + ) + self.route_handler_method_map = self._create_route_handler_method_map(self.routes) self.asgi_router.construct_routing_trie() @@ -789,18 +956,6 @@ def get_membership_details(group_id: int, user_id: int) -> None: return join_paths(output) - @property - def route_handler_method_view(self) -> dict[str, list[str]]: - """Map route handlers to paths. - - Returns: - A dictionary of router handlers and lists of paths as strings - """ - route_map: dict[str, list[str]] = { - handler: [route.path for route in routes] for handler, routes in self.asgi_router.route_mapping.items() - } - return route_map - def _create_asgi_handler(self) -> ASGIApp: """Create an ASGIApp that wraps the ASGI router inside an exception handler. @@ -859,3 +1014,14 @@ def emit(self, event_id: str, *args: Any, **kwargs: Any) -> None: None """ self.event_emitter.emit(event_id, *args, **kwargs) + + +def _maybe_add_options_handler( + path: str, http_handlers: list[HTTPRouteHandler], root: Router +) -> list[HTTPRouteHandler]: + handler_methods = {method for handler in http_handlers for method in handler.http_methods} + if "OPTIONS" not in handler_methods: + options_handler = create_options_handler(path=path, allow_methods={*handler_methods, "OPTIONS"}) # pyright: ignore + options_handler = options_handler.merge(root) + return [*http_handlers, options_handler] + return http_handlers diff --git a/litestar/connection/request.py b/litestar/connection/request.py index 065312d7ed..c66f44e066 100644 --- a/litestar/connection/request.py +++ b/litestar/connection/request.py @@ -137,7 +137,7 @@ async def json(self) -> Any: else: body = await self.body() self._json = self._connection_state.json = decode_json( - body or b"null", type_decoders=self.route_handler.resolve_type_decoders() + body or b"null", type_decoders=self.route_handler.type_decoders ) return self._json @@ -153,7 +153,7 @@ async def msgpack(self) -> Any: else: body = await self.body() self._msgpack = self._connection_state.msgpack = decode_msgpack( - body or b"\xc0", type_decoders=self.route_handler.resolve_type_decoders() + body or b"\xc0", type_decoders=self.route_handler.type_decoders ) return self._msgpack @@ -190,7 +190,7 @@ async def stream(self) -> AsyncGenerator[bytes, None]: # float is slightly faster than checking if a value is 'None' and then # comparing it to an int. since we expect a limit to be set most of the # time, this is a bit more efficient - max_content_length = self.route_handler.resolve_request_max_body_size() or math.inf + max_content_length = self.route_handler.request_max_body_size or math.inf # if the 'content-length' header is set, and exceeds the limit, we can bail # out early before reading anything diff --git a/litestar/connection/websocket.py b/litestar/connection/websocket.py index 0c7bc04404..0494f3bca5 100644 --- a/litestar/connection/websocket.py +++ b/litestar/connection/websocket.py @@ -211,7 +211,7 @@ async def receive_json(self, mode: WebSocketMode = "text") -> Any: An arbitrary value """ data = await self.receive_data(mode=mode) - return decode_json(value=data, type_decoders=self.route_handler.resolve_type_decoders()) + return decode_json(value=data, type_decoders=self.route_handler.type_decoders) async def receive_msgpack(self) -> Any: """Receive data and decode it as MessagePack. @@ -223,7 +223,7 @@ async def receive_msgpack(self) -> Any: An arbitrary value """ data = await self.receive_data(mode="binary") - return decode_msgpack(value=data, type_decoders=self.route_handler.resolve_type_decoders()) + return decode_msgpack(value=data, type_decoders=self.route_handler.type_decoders) async def iter_json(self, mode: WebSocketMode = "text") -> AsyncGenerator[Any, None]: """Continuously receive data and yield it, decoding it as JSON in the process. @@ -232,7 +232,7 @@ async def iter_json(self, mode: WebSocketMode = "text") -> AsyncGenerator[Any, N mode: Socket mode to use. Either ``text`` or ``binary`` """ async for data in self.iter_data(mode): - yield decode_json(value=data, type_decoders=self.route_handler.resolve_type_decoders()) + yield decode_json(value=data, type_decoders=self.route_handler.type_decoders) async def iter_msgpack(self) -> AsyncGenerator[Any, None]: """Continuously receive data and yield it, decoding it as MessagePack in the @@ -243,7 +243,7 @@ async def iter_msgpack(self) -> AsyncGenerator[Any, None]: """ async for data in self.iter_data(mode="binary"): - yield decode_msgpack(value=data, type_decoders=self.route_handler.resolve_type_decoders()) + yield decode_msgpack(value=data, type_decoders=self.route_handler.type_decoders) async def send_data(self, data: str | bytes, mode: WebSocketMode = "text", encoding: str = "utf-8") -> None: """Send a 'websocket.send' event. diff --git a/litestar/controller.py b/litestar/controller.py index 331e1037ad..1b6f70c480 100644 --- a/litestar/controller.py +++ b/litestar/controller.py @@ -2,7 +2,6 @@ import types from collections import defaultdict -from copy import deepcopy from operator import attrgetter from typing import TYPE_CHECKING, Any, Mapping, Sequence, cast @@ -60,7 +59,6 @@ class Controller: "include_in_schema", "middleware", "opt", - "owner", "parameters", "path", "request_class", @@ -121,11 +119,6 @@ class Controller: """A string key mapping of arbitrary values that can be accessed in :class:`Guards <.types.Guard>` or wherever you have access to :class:`Request <.connection.Request>` or :class:`ASGI Scope <.types.Scope>`. """ - owner: Router - """The :class:`Router <.router.Router>` or :class:`Litestar ` app that owns the controller. - - This value is set internally by Litestar and it should not be set when subclassing the controller. - """ parameters: ParametersMap | None """A mapping of :class:`Parameter <.params.Parameter>` definitions available to all application paths.""" path: str @@ -174,13 +167,10 @@ class Controller: handlers under the controller. """ - def __init__(self, owner: Router) -> None: + def __init__(self) -> None: """Initialize a controller. Should only be called by routers as part of controller registration. - - Args: - owner: An instance of :class:`Router <.router.Router>` """ # Since functions set on classes are bound, we need replace the bound instance with the class version for key in ("after_request", "after_response", "before_request"): @@ -211,12 +201,11 @@ def __init__(self, owner: Router) -> None: self.response_cookies = narrow_response_cookies(self.response_cookies) self.response_headers = narrow_response_headers(self.response_headers) self.path = normalize_path(self.path or "/") - self.owner = owner def as_router(self) -> Router: from litestar.router import Router - router = Router( + return Router( path=self.path, route_handlers=self.get_route_handlers(), # type: ignore[arg-type] after_request=self.after_request, @@ -246,11 +235,9 @@ def as_router(self) -> Router: websocket_class=self.websocket_class, request_max_body_size=self.request_max_body_size, ) - router.owner = self.owner - return router def get_route_handlers(self) -> list[BaseRouteHandler]: - """Get a controller's route handlers and set the controller as the handlers' owner. + """Get a controller's route handlers Returns: A list containing a copy of the route handlers defined on the controller @@ -258,19 +245,16 @@ def get_route_handlers(self) -> list[BaseRouteHandler]: route_handlers: list[BaseRouteHandler] = [] controller_names = set(dir(Controller)) - self_handlers = [ + self_handlers: list[BaseRouteHandler] = [ getattr(self, name) for name in dir(self) if name not in controller_names and isinstance(getattr(self, name), BaseRouteHandler) ] self_handlers.sort(key=attrgetter("handler_id")) for self_handler in self_handlers: - route_handler = deepcopy(self_handler) # at the point we get a reference to the handler function, it's unbound, so # we replace it with a regular bound method here - route_handler.fn = types.MethodType(route_handler.fn, self) - route_handler.owner = self - route_handlers.append(route_handler) + route_handlers.append(self_handler._with_changes(fn=types.MethodType(self_handler.fn, self))) self.validate_route_handlers(route_handlers=route_handlers) diff --git a/litestar/di.py b/litestar/di.py index b47f13478f..142e126f16 100644 --- a/litestar/di.py +++ b/litestar/di.py @@ -3,10 +3,14 @@ from inspect import isasyncgenfunction, isclass, isgeneratorfunction from typing import TYPE_CHECKING, Any +from litestar._signature import SignatureModel from litestar.exceptions import ImproperlyConfiguredException -from litestar.types import Empty +from litestar.plugins import DIPlugin, PluginRegistry +from litestar.types import Empty, TypeDecodersSequence from litestar.utils import ensure_async_callable +from litestar.utils.helpers import unwrap_partial from litestar.utils.predicates import is_async_callable +from litestar.utils.signature import ParsedSignature from litestar.utils.warnings import ( warn_implicit_sync_to_thread, warn_sync_to_thread_with_async_callable, @@ -14,9 +18,8 @@ ) if TYPE_CHECKING: - from litestar._signature import SignatureModel + from litestar.dto import AbstractDTO from litestar.types import AnyCallable - from litestar.utils.signature import ParsedSignature __all__ = ("Provide",) @@ -25,19 +28,17 @@ class Provide: """Wrapper class for dependency injection""" __slots__ = ( + "_parsed_fn_signature", + "_signature_model", "dependency", "has_async_generator_dependency", "has_sync_callable", "has_sync_generator_dependency", - "parsed_fn_signature", - "signature_model", "sync_to_thread", "use_cache", "value", ) - parsed_fn_signature: ParsedSignature - signature_model: type[SignatureModel] dependency: AnyCallable def __init__( @@ -90,6 +91,50 @@ def __init__( self.sync_to_thread = bool(sync_to_thread) self.use_cache = use_cache self.value: Any = Empty + self._parsed_fn_signature: ParsedSignature | None = None + self._signature_model: type[SignatureModel] | None = None + + @property + def signature_model(self) -> type[SignatureModel]: + if self._signature_model is None: + raise ValueError(f"Cannot access signature model of Provider {self} because it is not finalized") + return self._signature_model + + @property + def parsed_fn_signature(self) -> ParsedSignature: + if self._parsed_fn_signature is None: + raise ValueError(f"Cannot access parsed signature of Provider {self} because it is not finalized") + return self._parsed_fn_signature + + def finalize( + self, + *, + plugins: PluginRegistry, + signature_namespace: dict[str, Any], + dependency_keys: set[str], + data_dto: type[AbstractDTO] | None, + type_decoders: TypeDecodersSequence, + ) -> None: + if self._parsed_fn_signature is None: + dependency = unwrap_partial(self.dependency) + plugin = next( + (p for p in plugins.di if isinstance(p, DIPlugin) and p.has_typed_init(dependency)), + None, + ) + if plugin: + signature, init_type_hints = plugin.get_typed_init(dependency) + self._parsed_fn_signature = ParsedSignature.from_signature(signature, init_type_hints) + else: + self._parsed_fn_signature = ParsedSignature.from_fn(dependency, signature_namespace) + + if self._signature_model is None: + self._signature_model = SignatureModel.create( + dependency_name_set=dependency_keys, + fn=self.dependency, + parsed_signature=self.parsed_fn_signature, + data_dto=data_dto, + type_decoders=type_decoders, + ) async def __call__(self, **kwargs: Any) -> Any: """Call the provider's dependency.""" diff --git a/litestar/dto/_backend.py b/litestar/dto/_backend.py index 4d9651a19e..32e1d8c5fc 100644 --- a/litestar/dto/_backend.py +++ b/litestar/dto/_backend.py @@ -233,7 +233,7 @@ def parse_raw(self, raw: bytes, asgi_connection: ASGIConnection) -> Struct | Col if (content_type := getattr(asgi_connection, "content_type", None)) and (media_type := content_type[0]): request_encoding = media_type - type_decoders = asgi_connection.route_handler.resolve_type_decoders() + type_decoders = asgi_connection.route_handler.type_decoders if request_encoding == RequestEncodingType.MESSAGEPACK: result = decode_msgpack(value=raw, target_type=self.annotation, type_decoders=type_decoders, strict=False) diff --git a/litestar/exceptions/responses/_debug_response.py b/litestar/exceptions/responses/_debug_response.py index 99e8c8709a..84b872cb55 100644 --- a/litestar/exceptions/responses/_debug_response.py +++ b/litestar/exceptions/responses/_debug_response.py @@ -202,7 +202,7 @@ def create_debug_response(request: Request, exc: Exception) -> Response: def _get_type_encoders_for_request(request: Request) -> TypeEncodersMap | None: try: - return request.route_handler.resolve_type_encoders() + return request.route_handler.type_encoders # we might be in a 404, or before we could resolve the handler, so this # could potentially error out. In this case we fall back on the application # type encoders diff --git a/litestar/handlers/asgi_handlers.py b/litestar/handlers/asgi_handlers.py index a7c939fb25..ff7463e06f 100644 --- a/litestar/handlers/asgi_handlers.py +++ b/litestar/handlers/asgi_handlers.py @@ -12,13 +12,14 @@ if TYPE_CHECKING: - from litestar import Litestar + from litestar import Litestar, Router from litestar.connection import ASGIConnection from litestar.routes import BaseRoute from litestar.types import ( AsyncAnyCallable, ExceptionHandlersMap, Guard, + ParametersMap, ) @@ -40,6 +41,7 @@ def __init__( is_mount: bool = False, signature_namespace: Mapping[str, Any] | None = None, copy_scope: bool | None = None, + parameters: ParametersMap | None = None, **kwargs: Any, ) -> None: """Route handler for ASGI routes. @@ -64,6 +66,7 @@ def __init__( type_encoders: A mapping of types to callables that transform them into types supported for serialization. copy_scope: Copy the ASGI 'scope' before calling the mounted application. Should be set to 'True' unless side effects via scope mutations by the mounted ASGI application are intentional + parameters: A mapping of :func:`Parameter <.params.Parameter>` definitions **kwargs: Any additional kwarg - will be set in the opt dictionary. """ self.is_mount = is_mount @@ -77,11 +80,12 @@ def __init__( name=name, opt=opt, signature_namespace=signature_namespace, + parameters=parameters, **kwargs, ) - def on_registration(self, app: Litestar, route: BaseRoute) -> None: - super().on_registration(app, route=route) + def on_registration(self, route: BaseRoute, app: Litestar) -> None: + super().on_registration(app=app, route=route) if self.copy_scope is None: warnings.warn( @@ -93,6 +97,12 @@ def on_registration(self, app: Litestar, route: BaseRoute) -> None: stacklevel=1, ) + def _get_merge_opts(self, others: tuple[Router, ...]) -> dict[str, Any]: + merge_opts = super()._get_merge_opts(others) + merge_opts["is_mount"] = self.is_mount + merge_opts["copy_scope"] = self.copy_scope + return merge_opts + def _validate_handler_function(self) -> None: """Validate the route handler function once it's set by inspecting its return annotations.""" super()._validate_handler_function() @@ -119,7 +129,7 @@ async def handle(self, connection: ASGIConnection[ASGIRouteHandler, Any, Any, An None """ - if self.resolve_guards(): + if self.guards: await self.authorize_connection(connection=connection) await self.fn(scope=connection.scope, receive=connection.receive, send=connection.send) diff --git a/litestar/handlers/base.py b/litestar/handlers/base.py index 039bf602a6..511bae581e 100644 --- a/litestar/handlers/base.py +++ b/litestar/handlers/base.py @@ -1,14 +1,14 @@ from __future__ import annotations -from copy import copy +import functools from functools import partial -from typing import TYPE_CHECKING, Any, Callable, Iterable, Mapping, Sequence, cast +from typing import TYPE_CHECKING, Any, Callable, Iterable, Mapping, NoReturn, Sequence, cast from litestar._signature import SignatureModel from litestar.di import Provide from litestar.dto import DTOData -from litestar.exceptions import ImproperlyConfiguredException -from litestar.plugins import DIPlugin, PluginRegistry +from litestar.exceptions import ImproperlyConfiguredException, LitestarException +from litestar.router import Router from litestar.serialization import default_deserializer, default_serializer from litestar.types import ( Dependencies, @@ -16,11 +16,14 @@ ExceptionHandlersMap, Guard, Middleware, + ParametersMap, TypeDecodersSequence, TypeEncodersMap, ) from litestar.typing import FieldDefinition -from litestar.utils import ensure_async_callable, get_name, normalize_path +from litestar.utils import ensure_async_callable, get_name, join_paths, normalize_path +from litestar.utils.deprecation import deprecated +from litestar.utils.empty import value_or_default from litestar.utils.helpers import unwrap_partial from litestar.utils.signature import ParsedSignature, add_types_to_signature_namespace, merge_signature_namespaces @@ -30,12 +33,10 @@ from litestar._kwargs import KwargsModel from litestar.app import Litestar from litestar.connection import ASGIConnection - from litestar.controller import Controller from litestar.dto import AbstractDTO - from litestar.params import ParameterKwarg - from litestar.router import Router from litestar.routes import BaseRoute - from litestar.types import AnyCallable, AsyncAnyCallable, ExceptionHandler + from litestar.types import AsyncAnyCallable + from litestar.types.callable_types import AnyCallable, AsyncGuard from litestar.types.empty import EmptyType __all__ = ("BaseRouteHandler",) @@ -48,29 +49,22 @@ class BaseRouteHandler: """ __slots__ = ( + "_dto", + "_parameter_field_definitions", "_parsed_data_field", "_parsed_fn_signature", "_parsed_return_field", - "_resolved_data_dto", - "_resolved_dependencies", - "_resolved_guards", - "_resolved_layered_parameters", - "_resolved_return_dto", - "_resolved_signature_namespace", - "_resolved_type_decoders", - "_resolved_type_encoders", - "_signature_model", + "_resolved_signature_model", + "_return_dto", "dependencies", - "dto", "exception_handlers", "fn", "guards", "middleware", "name", "opt", - "owner", + "parameters", "paths", - "return_dto", "signature_namespace", "type_decoders", "type_encoders", @@ -91,6 +85,7 @@ def __init__( return_dto: type[AbstractDTO] | None | EmptyType = Empty, signature_namespace: Mapping[str, Any] | None = None, signature_types: Sequence[Any] | None = None, + parameters: ParametersMap | None = None, type_decoders: TypeDecodersSequence | None = None, type_encoders: TypeEncodersMap | None = None, **kwargs: Any, @@ -121,48 +116,126 @@ def __init__( These types will be added to the signature namespace using their ``__name__`` attribute. type_decoders: A sequence of tuples, each composed of a predicate testing for type identity and a msgspec hook for deserialization. type_encoders: A mapping of types to callables that transform them into types supported for serialization. + parameters: A mapping of :func:`Parameter <.params.Parameter>` definitions **kwargs: Any additional kwarg - will be set in the opt dictionary. """ self._parsed_fn_signature: ParsedSignature | EmptyType = Empty self._parsed_return_field: FieldDefinition | EmptyType = Empty self._parsed_data_field: FieldDefinition | None | EmptyType = Empty - self._resolved_data_dto: type[AbstractDTO] | None | EmptyType = Empty - self._resolved_dependencies: dict[str, Provide] | EmptyType = Empty - self._resolved_guards: list[Guard] | EmptyType = Empty - self._resolved_layered_parameters: dict[str, FieldDefinition] | EmptyType = Empty - self._resolved_return_dto: type[AbstractDTO] | None | EmptyType = Empty - self._resolved_signature_namespace: dict[str, Any] | EmptyType = Empty - self._resolved_type_decoders: TypeDecodersSequence | EmptyType = Empty - self._resolved_type_encoders: TypeEncodersMap | EmptyType = Empty - self._signature_model: type[SignatureModel] | EmptyType = Empty - - self.dependencies = dependencies - self.dto = dto - self.exception_handlers = exception_handlers - self.guards = guards - self.middleware = middleware + self._parameter_field_definitions: dict[str, FieldDefinition] | EmptyType = Empty + self._resolved_signature_model: type[SignatureModel] | EmptyType = Empty + + self.dependencies = ( + { + key: provider if isinstance(provider, Provide) else Provide(provider) + for key, provider in dependencies.items() + } + if dependencies + else {} + ) + self._dto = dto + self._return_dto = return_dto + self.exception_handlers = exception_handlers or {} + self.guards: tuple[AsyncGuard, ...] = tuple(ensure_async_callable(guard) for guard in guards) if guards else () + self.middleware = tuple(middleware) if middleware else () self.name = name self.opt = dict(opt or {}) self.opt.update(**kwargs) - self.owner: Controller | Router | None = None - self.return_dto = return_dto self.signature_namespace = add_types_to_signature_namespace( signature_types or [], dict(signature_namespace or {}) ) - self.type_decoders = type_decoders - self.type_encoders = type_encoders + self.type_decoders = type_decoders or () + self.type_encoders = type_encoders or {} self.paths = ( {normalize_path(p) for p in path} if path and isinstance(path, list) else {normalize_path(path or "/")} # type: ignore[arg-type] ) - self.fn = self._prepare_fn(fn) + self.fn = fn + self.parameters = parameters or {} + + def _get_merge_opts(self, others: tuple[Router, ...]) -> dict[str, Any]: + """Get kwargs for .merge. + + This is effectively the same as doing: - def _prepare_fn(self, fn: AsyncAnyCallable) -> AsyncAnyCallable: - return fn + for other in others: + handler = handler.merge(other) + + The downside of that approach is that it creates a bunch of intermediate + instances requires every subclass that adds properties to re-implement all the + merging logic. + With this approach, the subclass can instead override this method, call + super()._get_merge_opts(), and extend the dict returned by it. + + The downside here is that we don't get type safety (as long as annotating + **kwargs with TypedDicts isn't supported anyway). + + The plan is for this to go away in version 4, where we can move to fully static + handler config, separating the logic and configuration entirely. + """ + path = ( + functools.reduce( + lambda a, b: join_paths([a, b]), + (o.path for o in reversed(others)), + ) + if others + else "" + ) + merge_opts: dict[str, Any] = { + "fn": self.fn, + "name": self.name, + "path": [join_paths([path, p]) for p in self.paths], + } + + other: BaseRouteHandler | Router + for other in (self, *others): # type: ignore[assignment] + merge_opts["dependencies"] = {**other.dependencies, **merge_opts.get("dependencies", {})} + merge_opts["exception_handlers"] = {**other.exception_handlers, **merge_opts.get("exception_handlers", {})} + merge_opts["guards"] = (*other.guards, *merge_opts.get("guards", ())) + + merge_opts["middleware"] = (*other.middleware, *merge_opts.get("middleware", ())) + merge_opts["opt"] = {**other.opt, **merge_opts.get("opt", {})} + merge_opts["type_decoders"] = (*merge_opts.get("type_decoders", ()), *other.type_decoders) + merge_opts["type_encoders"] = {**merge_opts.get("type_encoders", {}), **other.type_encoders} + merge_opts["parameters"] = {**merge_opts.get("parameters", {}), **other.parameters} + merge_opts["signature_namespace"] = merge_signature_namespaces( + merge_opts.get("signature_namespace", {}), other.signature_namespace + ) + + # '.dto' on the router is the dto config value supplied by the users, + # whereas '.dto' on the handler is the fully resolved dto. The dto config on + # the handler is stored under '._dto', so we have to do this little workaround + if other is not self: + other = cast(Router, other) # mypy cannot narrow with the 'is not self' check + merge_opts["dto"] = value_or_default(merge_opts.get("dto", Empty), other.dto) + merge_opts["return_dto"] = value_or_default(merge_opts.get("return_dto", Empty), other.return_dto) + + merge_opts["dto"] = value_or_default(self._dto, merge_opts.get("dto", Empty)) + merge_opts["return_dto"] = value_or_default(self._return_dto, merge_opts.get("return_dto", Empty)) + + # due to the way we're traversing over the app layers, the middleware stack is + # constructed in the wrong order (handler > application). reversing the order + # here is easier than handling it correctly at every intermediary step. + # + # we only call this if 'others' is non-empty, to ensure we don't change anything + # if no layers have been merged (happens in '._with_changes' for example) + if others: + merge_opts["middleware"] = tuple(reversed(merge_opts["middleware"])) + + return merge_opts + + def _with_changes(self, **kwargs: Any) -> Self: + """Return a new instance of the handler, replacing attributes specified in **kwargs""" + opts = self._get_merge_opts(()) + opts.update(kwargs) + return type(self)(**opts) + + def merge(self, *others: Router) -> Self: + return type(self)(**self._get_merge_opts(others)) @property def handler_id(self) -> str: """A unique identifier used for generation of DTOs.""" - return f"{self!s}::{sum(id(layer) for layer in self.ownership_layers)}" + return f"{self!s}::{id(self)}" @property def default_deserializer(self) -> Callable[[Any, Any], Any]: @@ -172,7 +245,7 @@ def default_deserializer(self) -> Callable[[Any, Any], Any]: A default deserializer for the route handler. """ - return partial(default_deserializer, type_decoders=self.resolve_type_decoders()) + return partial(default_deserializer, type_decoders=self.type_decoders) @property def default_serializer(self) -> Callable[[Any], Any]: @@ -182,7 +255,7 @@ def default_serializer(self) -> Callable[[Any], Any]: A default serializer for the route handler. """ - return partial(default_serializer, type_encoders=self.resolve_type_encoders()) + return partial(default_serializer, type_encoders=self.type_encoders) @property def signature_model(self) -> type[SignatureModel]: @@ -192,15 +265,15 @@ def signature_model(self) -> type[SignatureModel]: A signature model for the route handler. """ - if self._signature_model is Empty: - self._signature_model = SignatureModel.create( - dependency_name_set=self.dependency_name_set, + if self._resolved_signature_model is Empty: + self._resolved_signature_model = SignatureModel.create( + dependency_name_set=set(self.dependencies.keys()), fn=cast("AnyCallable", self.fn), parsed_signature=self.parsed_fn_signature, - data_dto=self.resolve_data_dto(), - type_decoders=self.resolve_type_decoders(), + data_dto=self.data_dto, + type_decoders=self.type_decoders, ) - return self._signature_model + return self._resolved_signature_model @property def parsed_fn_signature(self) -> ParsedSignature: @@ -212,9 +285,7 @@ def parsed_fn_signature(self) -> ParsedSignature: A ParsedSignature instance """ if self._parsed_fn_signature is Empty: - self._parsed_fn_signature = ParsedSignature.from_fn( - unwrap_partial(self.fn), self.resolve_signature_namespace() - ) + self._parsed_fn_signature = ParsedSignature.from_fn(unwrap_partial(self.fn), self.signature_namespace) return self._parsed_fn_signature @@ -242,196 +313,116 @@ def handler_name(self) -> str: """ return get_name(unwrap_partial(self.fn)) - @property - def dependency_name_set(self) -> set[str]: - """Set of all dependency names provided in the handler's ownership layers.""" - layered_dependencies = (layer.dependencies or {} for layer in self.ownership_layers) - return {name for layer in layered_dependencies for name in layer} # pyright: ignore - - @property - def ownership_layers(self) -> list[Self | Controller | Router]: - """Return the handler layers from the app down to the route handler. - - ``app -> ... -> route handler`` - """ - layers = [] - - cur: Any = self - while cur: - layers.append(cur) - cur = cur.owner - - return list(reversed(layers)) - - @property - def app(self) -> Litestar: - return cast("Litestar", self.ownership_layers[0]) + def _raise_not_registered(self) -> NoReturn: + raise LitestarException( + f"Handler {self!r}: Accessing this attribute is unsafe until the handler has been " + "registered with an application, as it may yield different results after registration." + ) + @deprecated("3.0", removal_in="4.0", alternative=".type_encoders attribute") def resolve_type_encoders(self) -> TypeEncodersMap: """Return a merged type_encoders mapping. - This method is memoized so the computation occurs only once. - Returns: A dict of type encoders """ - if self._resolved_type_encoders is Empty: - self._resolved_type_encoders = {} - for layer in self.ownership_layers: - if type_encoders := getattr(layer, "type_encoders", None): - self._resolved_type_encoders.update(type_encoders) - return cast("TypeEncodersMap", self._resolved_type_encoders) + return self.type_encoders + @deprecated("3.0", removal_in="4.0", alternative=".type_decoders attribute") def resolve_type_decoders(self) -> TypeDecodersSequence: """Return a merged type_encoders mapping. - This method is memoized so the computation occurs only once. - Returns: A dict of type encoders """ - if self._resolved_type_decoders is Empty: - self._resolved_type_decoders = [] - for layer in self.ownership_layers: - if type_decoders := getattr(layer, "type_decoders", None): - self._resolved_type_decoders.extend(list(type_decoders)) - return cast("TypeDecodersSequence", self._resolved_type_decoders) + return self.type_decoders + @deprecated("3.0", removal_in="4.0", alternative=".parameter_field_definitions property") def resolve_layered_parameters(self) -> dict[str, FieldDefinition]: - """Return all parameters declared above the handler.""" - if self._resolved_layered_parameters is Empty: - parameter_kwargs: dict[str, ParameterKwarg] = {} + return self.parameter_field_definitions - for layer in self.ownership_layers: - parameter_kwargs.update(getattr(layer, "parameters", {}) or {}) - - self._resolved_layered_parameters = { + @property + def parameter_field_definitions(self) -> dict[str, FieldDefinition]: + """Return all parameters declared above the handler.""" + if self._parameter_field_definitions is Empty: + self._parameter_field_definitions = { key: FieldDefinition.from_kwarg(name=key, annotation=parameter.annotation, kwarg_definition=parameter) - for key, parameter in parameter_kwargs.items() + for key, parameter in self.parameters.items() } + return self._parameter_field_definitions - return self._resolved_layered_parameters - - def resolve_guards(self) -> list[Guard]: + @deprecated("3.0", removal_in="4.0", alternative=".guards attribute") + def resolve_guards(self) -> tuple[Guard, ...]: """Return all guards in the handlers scope, starting from highest to current layer.""" - if self._resolved_guards is Empty: - self._resolved_guards = [] - for layer in self.ownership_layers: - self._resolved_guards.extend(layer.guards or []) # pyright: ignore + return self.guards - self._resolved_guards = cast( - "list[Guard]", [ensure_async_callable(guard) for guard in self._resolved_guards] - ) + @deprecated("3.0", removal_in="4.0", alternative=".dependencies attribute") + def resolve_dependencies(self) -> dict[str, Provide]: + """Return all dependencies correlating to handler function's kwargs that exist in the handler's scope.""" - return self._resolved_guards + return self.dependencies - def _get_plugin_registry(self) -> PluginRegistry | None: - from litestar.app import Litestar + def _finalize_dependencies(self, app: Litestar) -> None: + dependencies: dict[str, Provide] = {} - root_owner = self.ownership_layers[0] - if isinstance(root_owner, Litestar): - return root_owner.plugins - return None + # keep track of which providers are available for each dependency + provider_keys: dict[Any, str] = {} - def resolve_dependencies(self) -> dict[str, Provide]: - """Return all dependencies correlating to handler function's kwargs that exist in the handler's scope.""" - plugin_registry = self._get_plugin_registry() - if self._resolved_dependencies is Empty: - self._resolved_dependencies = {} - for layer in self.ownership_layers: - for key, provider in (layer.dependencies or {}).items(): - self._resolved_dependencies[key] = self._resolve_dependency( - key=key, provider=provider, plugin_registry=plugin_registry - ) - - return self._resolved_dependencies - - def _resolve_dependency( - self, key: str, provider: Provide | AnyCallable, plugin_registry: PluginRegistry | None - ) -> Provide: - if not isinstance(provider, Provide): - provider = Provide(provider) - - if self._resolved_dependencies is not Empty: # pragma: no cover - self._validate_dependency_is_unique(dependencies=self._resolved_dependencies, key=key, provider=provider) - - if not getattr(provider, "parsed_fn_signature", None): - dependency = unwrap_partial(provider.dependency) - plugin: DIPlugin | None = None - if plugin_registry: - plugin = next( - (p for p in plugin_registry.di if isinstance(p, DIPlugin) and p.has_typed_init(dependency)), - None, + for key, provider in self.dependencies.items(): + # ensure that if a provider for this dependency has already been registered, + # registering this provider again is only allowed as an override, i.e. with + # the same key + if (existing_key := provider_keys.get(provider.dependency)) and existing_key != key: + raise ImproperlyConfiguredException( + f"Provider for {provider.dependency!r} with key {key!r} is already defined under a different key " + f"{existing_key!r}. If you wish to override a provider, it must have the same key." ) - if plugin: - signature, init_type_hints = plugin.get_typed_init(dependency) - provider.parsed_fn_signature = ParsedSignature.from_signature(signature, init_type_hints) - else: - provider.parsed_fn_signature = ParsedSignature.from_fn(dependency, self.resolve_signature_namespace()) - - if not getattr(provider, "signature_model", None): - provider.signature_model = SignatureModel.create( - dependency_name_set=self.dependency_name_set, - fn=provider.dependency, - parsed_signature=provider.parsed_fn_signature, - data_dto=self.resolve_data_dto(), - type_decoders=self.resolve_type_decoders(), + + provider.finalize( + plugins=app.plugins, + signature_namespace=self.signature_namespace, + data_dto=self.data_dto, + dependency_keys=set(self.dependencies), + type_decoders=self.type_decoders, ) - return provider + provider_keys[provider.dependency] = key + dependencies[key] = provider - def resolve_middleware(self) -> list[Middleware]: - """Build the middleware stack for the RouteHandler and return it. + @deprecated("3.0", removal_in="4.0", alternative=".middleware attribute") + def resolve_middleware(self) -> tuple[Middleware, ...]: + """Return registered middlewares""" - The middlewares are added from top to bottom (``app -> router -> controller -> route handler``) and then - reversed. - """ - resolved_middleware: list[Middleware] = [] - for layer in self.ownership_layers: - resolved_middleware.extend(layer.middleware or []) # pyright: ignore - return list(reversed(resolved_middleware)) + return self.middleware + @deprecated("3.0", removal_in="4.0", alternative=".exception_handlers attribute") def resolve_exception_handlers(self) -> ExceptionHandlersMap: """Resolve the exception_handlers by starting from the route handler and moving up. This method is memoized so the computation occurs only once. """ - resolved_exception_handlers: dict[int | type[Exception], ExceptionHandler] = {} - for layer in self.ownership_layers: - resolved_exception_handlers.update(layer.exception_handlers or {}) # pyright: ignore - return resolved_exception_handlers - def resolve_opts(self) -> None: - """Build the route handler opt dictionary by going from top to bottom. - - When merging keys from multiple layers, if the same key is defined by multiple layers, the value from the - layer closest to the response handler will take precedence. - """ - - opt: dict[str, Any] = {} - for layer in self.ownership_layers: - opt.update(layer.opt or {}) # pyright: ignore - - self.opt = opt + return self.exception_handlers + @deprecated("3.0", removal_in="4.0", alternative=".signature_namespace attribute") def resolve_signature_namespace(self) -> dict[str, Any]: - """Build the route handler signature namespace dictionary by going from top to bottom. + """Build the route handler signature namespace dictionary by going from top to bottom""" - When merging keys from multiple layers, if the same key is defined by multiple layers, the value from the - layer closest to the response handler will take precedence. - """ - if self._resolved_signature_namespace is Empty: - ns: dict[str, Any] = {} - for layer in self.ownership_layers: - merge_signature_namespaces( - signature_namespace=ns, additional_signature_namespace=layer.signature_namespace - ) - self._resolved_signature_namespace = ns - return self._resolved_signature_namespace + return self.signature_namespace + + @property + def data_dto(self) -> type[AbstractDTO] | None: + if self._dto is Empty: + self._raise_not_registered() + return self._dto + @deprecated("3.0", removal_in="4.0", alternative=".data_dto attribute") def resolve_data_dto(self) -> type[AbstractDTO] | None: + return self.data_dto + + def _resolve_data_dto(self, app: Litestar) -> type[AbstractDTO] | None: """Resolve the data_dto by starting from the route handler and moving up. If a handler is found it is returned, otherwise None is set. This method is memoized so the computation occurs only once. @@ -439,34 +430,40 @@ def resolve_data_dto(self) -> type[AbstractDTO] | None: Returns: An optional :class:`DTO type <.dto.base_dto.AbstractDTO>` """ - if self._resolved_data_dto is Empty: - if data_dtos := cast( - "list[type[AbstractDTO] | None]", - [layer.dto for layer in self.ownership_layers if layer.dto is not Empty], - ): - data_dto: type[AbstractDTO] | None = data_dtos[-1] - elif self.parsed_data_field and ( - plugins_for_data_type := [ + data_dto: type[AbstractDTO] | None = None + if (_data_dto := self._dto) is not Empty: + data_dto = _data_dto + elif self.parsed_data_field and ( + plugin_for_data_type := next( + ( plugin - for plugin in self.app.plugins.serialization + for plugin in app.plugins.serialization if self.parsed_data_field.match_predicate_recursively(plugin.supports_type) - ] - ): - data_dto = plugins_for_data_type[0].create_dto_for_type(self.parsed_data_field) - else: - data_dto = None - - if self.parsed_data_field and data_dto: - data_dto.create_for_field_definition( - field_definition=self.parsed_data_field, - handler_id=self.handler_id, - ) + ), + None, + ) + ): + data_dto = plugin_for_data_type.create_dto_for_type(self.parsed_data_field) - self._resolved_data_dto = data_dto + if self.parsed_data_field and data_dto: + data_dto.create_for_field_definition( + field_definition=self.parsed_data_field, + handler_id=self.handler_id, + ) + + return data_dto - return self._resolved_data_dto + @property + def return_dto(self) -> type[AbstractDTO] | None: + if self._return_dto is Empty: + self._raise_not_registered() + return self._return_dto + @deprecated("3.0", removal_in="4.0", alternative=".return_dto attribute") def resolve_return_dto(self) -> type[AbstractDTO] | None: + return self.return_dto + + def _resolve_return_dto(self, app: Litestar, data_dto: type[AbstractDTO] | None) -> type[AbstractDTO] | None: """Resolve the return_dto by starting from the route handler and moving up. If a handler is found it is returned, otherwise None is set. This method is memoized so the computation occurs only once. @@ -474,72 +471,56 @@ def resolve_return_dto(self) -> type[AbstractDTO] | None: Returns: An optional :class:`DTO type <.dto.base_dto.AbstractDTO>` """ - if self._resolved_return_dto is Empty: - if return_dtos := cast( - "list[type[AbstractDTO] | None]", - [layer.return_dto for layer in self.ownership_layers if layer.return_dto is not Empty], - ): - return_dto: type[AbstractDTO] | None = return_dtos[-1] - elif plugins_for_return_type := [ + if (_return_dto := self._return_dto) is not Empty: + return_dto: type[AbstractDTO] | None = _return_dto + elif plugin_for_return_type := next( + ( plugin - for plugin in self.app.plugins.serialization + for plugin in app.plugins.serialization if self.parsed_return_field.match_predicate_recursively(plugin.supports_type) - ]: - return_dto = plugins_for_return_type[0].create_dto_for_type(self.parsed_return_field) - else: - return_dto = self.resolve_data_dto() - - if return_dto and return_dto.is_supported_model_type_field(self.parsed_return_field): - return_dto.create_for_field_definition( - field_definition=self.parsed_return_field, - handler_id=self.handler_id, - ) - self._resolved_return_dto = return_dto - else: - self._resolved_return_dto = None + ), + None, + ): + return_dto = plugin_for_return_type.create_dto_for_type(self.parsed_return_field) + else: + return_dto = data_dto + + if return_dto and return_dto.is_supported_model_type_field(self.parsed_return_field): + return_dto.create_for_field_definition( + field_definition=self.parsed_return_field, + handler_id=self.handler_id, + ) + resolved_return_dto = return_dto + else: + resolved_return_dto = None - return self._resolved_return_dto + return resolved_return_dto async def authorize_connection(self, connection: ASGIConnection) -> None: """Ensure the connection is authorized by running all the route guards in scope.""" - for guard in self.resolve_guards(): - await guard(connection, copy(self)) # type: ignore[misc] - - @staticmethod - def _validate_dependency_is_unique(dependencies: dict[str, Provide], key: str, provider: Provide) -> None: - """Validate that a given provider has not been already defined under a different key.""" - for dependency_key, value in dependencies.items(): - if provider == value: - raise ImproperlyConfiguredException( - f"Provider for key {key} is already defined under the different key {dependency_key}. " - f"If you wish to override a provider, it must have the same key." - ) + for guard in self.guards: + await guard(connection, self) - def on_registration(self, app: Litestar, route: BaseRoute) -> None: + def on_registration(self, route: BaseRoute, app: Litestar) -> None: """Called once per handler when the app object is instantiated. Args: - app: The :class:`Litestar<.app.Litestar>` app object. route: The route this handler is being registered on + app: The application instance Returns: None """ + + self._dto = self._resolve_data_dto(app=app) + self._return_dto = self._resolve_return_dto(app=app, data_dto=self._dto) + self._validate_handler_function() - self.resolve_dependencies() - self.resolve_guards() - self.resolve_middleware() - self.resolve_opts() - self.resolve_data_dto() - self.resolve_return_dto() + self._finalize_dependencies(app=app) def _validate_handler_function(self) -> None: """Validate the route handler function once set by inspecting its return annotations.""" - if ( - self.parsed_data_field is not None - and self.parsed_data_field.is_subclass_of(DTOData) - and not self.resolve_data_dto() - ): + if self.parsed_data_field is not None and self.parsed_data_field.is_subclass_of(DTOData) and not self.data_dto: raise ImproperlyConfiguredException( f"Handler function {self.handler_name} has a data parameter that is a subclass of DTOData but no " "DTO has been registered for it." @@ -567,7 +548,7 @@ def _create_kwargs_model( return KwargsModel.create_for_signature_model( signature_model=self.signature_model, parsed_signature=self.parsed_fn_signature, - dependencies=self.resolve_dependencies(), + dependencies=self.dependencies, path_parameters=set(path_parameters), - layered_parameters=self.resolve_layered_parameters(), + layered_parameters=self.parameter_field_definitions, ) diff --git a/litestar/handlers/http_handlers/_utils.py b/litestar/handlers/http_handlers/_utils.py index d24b238f3b..26ca414f0e 100644 --- a/litestar/handlers/http_handlers/_utils.py +++ b/litestar/handlers/http_handlers/_utils.py @@ -16,6 +16,7 @@ from litestar.connection import Request from litestar.datastructures import Cookie, ResponseHeader from litestar.types import AfterRequestHookHandler, ASGIApp, AsyncAnyCallable, Method, TypeEncodersMap + from litestar.types.asgi_types import HttpMethodName from litestar.typing import FieldDefinition __all__ = ( @@ -160,7 +161,7 @@ async def handler( return handler -def normalize_http_method(http_methods: HttpMethod | Method | Sequence[HttpMethod | Method]) -> set[Method]: +def normalize_http_method(http_methods: Method | Sequence[Method]) -> set[HttpMethodName]: """Normalize HTTP method(s) into a set of upper-case method names. Args: @@ -180,10 +181,10 @@ def normalize_http_method(http_methods: HttpMethod | Method | Sequence[HttpMetho raise ValidationException(f"Invalid HTTP method: {method_name}") output.add(method_name) - return cast("set[Method]", output) + return cast("set[HttpMethodName]", output) -def get_default_status_code(http_methods: set[Method]) -> int: +def get_default_status_code(http_methods: set[HttpMethodName]) -> int: """Return the default status code for a given set of HTTP methods. Args: diff --git a/litestar/handlers/http_handlers/base.py b/litestar/handlers/http_handlers/base.py index b9c11461e5..5b4f2905f8 100644 --- a/litestar/handlers/http_handlers/base.py +++ b/litestar/handlers/http_handlers/base.py @@ -1,14 +1,13 @@ from __future__ import annotations from enum import Enum -from typing import TYPE_CHECKING, AnyStr, Iterable, Mapping, Sequence, TypedDict, cast +from typing import TYPE_CHECKING, AnyStr, Awaitable, Callable, Iterable, Mapping, Sequence, TypedDict, cast from msgspec.msgpack import decode as _decode_msgpack_plain from litestar._layers.utils import narrow_response_cookies, narrow_response_headers from litestar.connection import Request -from litestar.datastructures import CacheControlHeader, ETag, FormMultiDict -from litestar.datastructures.cookie import Cookie +from litestar.datastructures import CacheControlHeader, ETag, FormMultiDict, Header from litestar.datastructures.response_header import ResponseHeader from litestar.enums import HttpMethod, MediaType from litestar.exceptions import ( @@ -52,25 +51,35 @@ TypeEncodersMap, ) from litestar.types.builtin_types import NoneType +from litestar.utils import deprecated as litestar_deprecated from litestar.utils import ensure_async_callable +from litestar.utils.empty import value_or_default from litestar.utils.predicates import is_async_callable, is_class_and_subclass from litestar.utils.scope.state import ScopeState from litestar.utils.warnings import warn_implicit_sync_to_thread, warn_sync_to_thread_with_async_callable if TYPE_CHECKING: - from typing import Any, Awaitable, Callable + from typing import Any + from litestar import Litestar, Router from litestar._kwargs import KwargsModel from litestar._kwargs.cleanup import DependencyCleanupGroup - from litestar.app import Litestar from litestar.background_tasks import BackgroundTask, BackgroundTasks from litestar.config.response_cache import CACHE_FOREVER + from litestar.datastructures.cookie import Cookie from litestar.dto import AbstractDTO from litestar.openapi.datastructures import ResponseSpec from litestar.openapi.spec import SecurityRequirement from litestar.routes import BaseRoute - from litestar.types.callable_types import AsyncAnyCallable, OperationIDCreator - from litestar.types.composite_types import TypeDecodersSequence + from litestar.types.callable_types import ( + AsyncAfterRequestHookHandler, + AsyncAfterResponseHookHandler, + AsyncAnyCallable, + AsyncBeforeRequestHookHandler, + OperationIDCreator, + ) + from litestar.types.composite_types import ParametersMap, TypeDecodersSequence + from litestar.typing import FieldDefinition __all__ = ("HTTPRouteHandler",) @@ -82,16 +91,14 @@ class ResponseHandlerMap(TypedDict): class HTTPRouteHandler(BaseRouteHandler): __slots__ = ( + "_default_response_handler", + "_include_in_schema", "_kwargs_models", - "_resolved_after_response", - "_resolved_before_request", - "_resolved_include_in_schema", - "_resolved_request_class", - "_resolved_request_max_body_size", - "_resolved_response_class", - "_resolved_security", - "_resolved_tags", - "_response_handler_mapping", + "_request_class", + "_request_max_body_size", + "_response_class", + "_response_type_handler", + "_sync_to_thread", "after_request", "after_response", "background", @@ -106,14 +113,10 @@ class HTTPRouteHandler(BaseRouteHandler): "etag", "has_sync_callable", "http_methods", - "include_in_schema", "media_type", "operation_class", "operation_id", "raises", - "request_class", - "request_max_body_size", - "response_class", "response_cookies", "response_description", "response_headers", @@ -121,9 +124,7 @@ class HTTPRouteHandler(BaseRouteHandler): "security", "status_code", "summary", - "sync_to_thread", "tags", - "template_name", ) def __init__( @@ -143,7 +144,7 @@ def __init__( etag: ETag | None = None, exception_handlers: ExceptionHandlersMap | None = None, guards: Sequence[Guard] | None = None, - http_method: HttpMethod | Method | Sequence[HttpMethod | Method], + http_method: Method | Sequence[Method], media_type: MediaType | str | None = None, middleware: Sequence[Middleware] | None = None, name: str | None = None, @@ -173,6 +174,7 @@ def __init__( tags: Sequence[str] | None = None, type_decoders: TypeDecodersSequence | None = None, type_encoders: TypeEncodersMap | None = None, + parameters: ParametersMap | None = None, **kwargs: Any, ) -> None: """Route handler for HTTP routes. @@ -240,6 +242,7 @@ def __init__( include_in_schema: A boolean flag dictating whether the route handler should be documented in the OpenAPI schema. operation_class: :class:`Operation <.openapi.spec.operation.Operation>` to be used with the route's OpenAPI schema. operation_id: Either a string or a callable returning a string. An identifier used for the route's schema operationId. + raises: A list of exception classes extending from litestar.HttpException that is used for the OpenAPI documentation. This list should describe all exceptions raised within the route handler's function/method. The Litestar ValidationException will be added automatically for the schema if any validation is involved. @@ -249,6 +252,7 @@ def __init__( tags: A sequence of string tags that will be appended to the OpenAPI schema. type_decoders: A sequence of tuples, each composed of a predicate testing for type identity and a msgspec hook for deserialization. type_encoders: A mapping of types to callables that transform them into types supported for serialization. + parameters: A mapping of :func:`Parameter <.params.Parameter>` definitions **kwargs: Any additional kwarg - will be set in the opt dictionary. """ if not http_method: @@ -256,16 +260,20 @@ def __init__( self.http_methods = normalize_http_method(http_methods=http_method) self.status_code = status_code or get_default_status_code(http_methods=self.http_methods) + self._sync_to_thread = sync_to_thread - if has_sync_callable := not is_async_callable(fn): + if not is_async_callable(fn): if sync_to_thread is None: warn_implicit_sync_to_thread(fn, stacklevel=3) elif sync_to_thread is not None: warn_sync_to_thread_with_async_callable(fn, stacklevel=3) + has_sync_callable = not is_async_callable(fn) + if has_sync_callable and sync_to_thread: fn = ensure_async_callable(fn) has_sync_callable = False + self.has_sync_callable = has_sync_callable super().__init__( fn=fn, @@ -281,51 +289,117 @@ def __init__( signature_namespace=signature_namespace, type_decoders=type_decoders, type_encoders=type_encoders, + parameters=parameters, **kwargs, ) - self.after_request = ensure_async_callable(after_request) if after_request else None # pyright: ignore - self.after_response = ensure_async_callable(after_response) if after_response else None + self.after_request: AsyncAfterRequestHookHandler | None = ( + ensure_async_callable(after_request) if after_request else None # type: ignore[assignment] + ) + self.after_response: AsyncAfterResponseHookHandler | None = ( + ensure_async_callable(after_response) if after_response else None + ) self.background = background - self.before_request = ensure_async_callable(before_request) if before_request else None + self.before_request: AsyncBeforeRequestHookHandler | None = ( + ensure_async_callable(before_request) if before_request else None + ) self.cache = cache self.cache_control = cache_control self.cache_key_builder = cache_key_builder self.etag = etag self.media_type: MediaType | str = media_type or "" - self.request_class = request_class - self.response_class = response_class - self.response_cookies: Sequence[Cookie] | None = narrow_response_cookies(response_cookies) - self.response_headers: Sequence[ResponseHeader] | None = narrow_response_headers(response_headers) - self.request_max_body_size = request_max_body_size + self._request_class = request_class + self._response_class = response_class + self.response_cookies = frozenset(narrow_response_cookies(response_cookies if response_cookies else ())) + self.response_headers: frozenset[ResponseHeader] = self._resolve_response_headers( + response_headers, + self.etag, + self.cache_control, + ) + self._request_max_body_size = request_max_body_size - self.has_sync_callable = has_sync_callable # OpenAPI related attributes self.content_encoding = content_encoding self.content_media_type = content_media_type self.deprecated = deprecated self.description = description - self.include_in_schema = include_in_schema + self._include_in_schema = include_in_schema self.operation_class = operation_class self.operation_id = operation_id self.raises = raises self.response_description = response_description self.summary = summary - self.tags = tags - self.security = security + self.tags = frozenset(tags) if tags else frozenset() + self.security = tuple(security) if security else () self.responses = responses # memoized attributes, defaulted to Empty - self._resolved_after_response: AsyncAnyCallable | None | EmptyType = Empty - self._resolved_before_request: AsyncAnyCallable | None | EmptyType = Empty - self._response_handler_mapping: ResponseHandlerMap = {"default_handler": Empty, "response_type_handler": Empty} - self._resolved_include_in_schema: bool | EmptyType = Empty - self._resolved_response_class: type[Response] | EmptyType = Empty - self._resolved_request_class: type[Request] | EmptyType = Empty - self._resolved_security: list[SecurityRequirement] | EmptyType = Empty - self._resolved_tags: list[str] | EmptyType = Empty self._kwargs_models: dict[tuple[str, ...], KwargsModel] = {} - self._resolved_request_max_body_size: int | EmptyType | None = Empty + self._default_response_handler: Callable[[Any], Awaitable[ASGIApp]] | EmptyType = Empty + self._response_type_handler: Callable[[Any], Awaitable[ASGIApp]] | EmptyType = Empty + + def _get_merge_opts(self, others: tuple[Router, ...]) -> dict[str, Any]: + merge_opts = super()._get_merge_opts(others) + + # these only exist on the handler, and therefore don't need merging + merge_opts.update( + background=self.background, + http_method=tuple(self.http_methods), + cache=self.cache, + media_type=self.media_type, + status_code=self.status_code, + # OpenAPI related attributes + content_encoding=self.content_encoding, + content_media_type=self.content_media_type, + deprecated=self.deprecated, + description=self.description, + operation_class=self.operation_class, + operation_id=self.operation_id, + raises=self.raises, + response_description=self.response_description, + responses=self.responses, + summary=self.summary, + sync_to_thread=False if self.has_sync_callable else None, + cache_key_builder=self.cache_key_builder, + ) + other: HTTPRouteHandler | Router + for other in (self, *others): # type: ignore[assignment] + merge_opts["after_response"] = merge_opts.get("after_response") or other.after_response + merge_opts["after_request"] = merge_opts.get("after_request") or other.after_request + merge_opts["before_request"] = merge_opts.get("before_request") or other.before_request + merge_opts["cache_control"] = merge_opts.get("cache_control") or other.cache_control + merge_opts["etag"] = merge_opts.get("etag") or other.etag + merge_opts["response_cookies"] = (*merge_opts.get("response_cookies", ()), *other.response_cookies) + merge_opts["response_headers"] = (*other.response_headers, *merge_opts.get("response_headers", ())) + merge_opts["security"] = (*other.security, *merge_opts.get("security", ())) + merge_opts["tags"] = (*other.tags, *merge_opts.get("tags", ())) + + # these are all properties which return a safe default if the corresponding + # config value is not set, so we only merge the router configs and determine + # the final value after that. + # ideally, a route handler wouldn't be able to take these in as optional, + # but that would require another 'finalized' route handler, which differs + # from the config object. this is planned for version 4 + if other is not self: + merge_opts["request_class"] = merge_opts.get("request_class") or other.request_class + merge_opts["response_class"] = merge_opts.get("response_class") or other.response_class + merge_opts["include_in_schema"] = value_or_default( + merge_opts.get("include_in_schema", Empty), other.include_in_schema + ) + + merge_opts["request_class"] = self._request_class or merge_opts.get("request_class") + merge_opts["response_class"] = self._response_class or merge_opts.get("response_class") + merge_opts["request_max_body_size"] = value_or_default( + self._request_max_body_size, + next((o.request_max_body_size for o in others if o.request_max_body_size is not Empty), Empty), + ) + merge_opts["include_in_schema"] = value_or_default( + self._include_in_schema, merge_opts.get("include_in_schema", Empty) + ) + + return merge_opts + + @litestar_deprecated("3.0", removal_in="4.0", alternative=".request_class property") def resolve_request_class(self) -> type[Request]: """Return the closest custom Request class in the owner graph or the default Request class. @@ -335,14 +409,13 @@ def resolve_request_class(self) -> type[Request]: The default :class:`Request <.connection.Request>` class for the route handler. """ - if self._resolved_request_class is Empty: - self._resolved_request_class = next( - (layer.request_class for layer in reversed(self.ownership_layers) if layer.request_class is not None), - Request, - ) + return self.request_class - return cast("type[Request]", self._resolved_request_class) + @property + def request_class(self) -> type[Request]: + return self._request_class or Request + @litestar_deprecated("3.0", removal_in="4.0", alternative=".response_class property") def resolve_response_class(self) -> type[Response]: """Return the closest custom Response class in the owner graph or the default Response class. @@ -351,61 +424,46 @@ def resolve_response_class(self) -> type[Response]: Returns: The default :class:`Response <.response.Response>` class for the route handler. """ - if self._resolved_response_class is Empty: - self._resolved_response_class = next( - (layer.response_class for layer in reversed(self.ownership_layers) if layer.response_class is not None), - Response, - ) + return self.response_class - return cast("type[Response]", self._resolved_response_class) + @property + def response_class(self) -> type[Response]: + return self._response_class or Response + @litestar_deprecated("3.0", removal_in="4.0", alternative=".response_headers attribute") def resolve_response_headers(self) -> frozenset[ResponseHeader]: + return self.response_headers + + @staticmethod + def _resolve_response_headers( + response_headers: ResponseHeaders | None, + *extra_headers: Header | None, + ) -> frozenset[ResponseHeader]: """Return all header parameters in the scope of the handler function. Returns: A dictionary mapping keys to :class:`ResponseHeader <.datastructures.ResponseHeader>` instances. """ - resolved_response_headers: dict[str, ResponseHeader] = {} - - for layer in self.ownership_layers: - if layer_response_headers := layer.response_headers: - if isinstance(layer_response_headers, Mapping): - # this can't happen unless you manually set response_headers on an instance, which would result in a - # type-checking error on everything but the controller. We cover this case nevertheless - resolved_response_headers.update( - {name: ResponseHeader(name=name, value=value) for name, value in layer_response_headers.items()} - ) - else: - resolved_response_headers.update({h.name: h for h in layer_response_headers}) - for extra_header in ("cache_control", "etag"): - if header_model := getattr(layer, extra_header, None): - resolved_response_headers[header_model.HEADER_NAME] = ResponseHeader( - name=header_model.HEADER_NAME, - value=header_model.to_header(), - documentation_only=header_model.documentation_only, - ) + resolved_response_headers: dict[str, ResponseHeader] = ( + {h.name: h for h in narrow_response_headers(response_headers)} if response_headers else {} + ) + + for extra_header in extra_headers: + if extra_header is None: + continue + resolved_response_headers[extra_header.HEADER_NAME] = ResponseHeader( + name=extra_header.HEADER_NAME, + value=extra_header.to_header(), + documentation_only=extra_header.documentation_only, + ) return frozenset(resolved_response_headers.values()) + @litestar_deprecated("3.0", removal_in="4.0", alternative=".response_cookies attribute") def resolve_response_cookies(self) -> frozenset[Cookie]: - """Return a list of Cookie instances. Filters the list to ensure each cookie key is unique. - - Returns: - A list of :class:`Cookie <.datastructures.Cookie>` instances. - """ - response_cookies: set[Cookie] = set() - for layer in reversed(self.ownership_layers): - if layer_response_cookies := layer.response_cookies: - if isinstance(layer_response_cookies, Mapping): - # this can't happen unless you manually set response_cookies on an instance, which would result in a - # type-checking error on everything but the controller. We cover this case nevertheless - response_cookies.update( - {Cookie(key=key, value=value) for key, value in layer_response_cookies.items()} - ) - else: - response_cookies.update(cast("set[Cookie]", layer_response_cookies)) - return frozenset(response_cookies) + return self.response_cookies + @litestar_deprecated("3.0", removal_in="4.0", alternative=".before_request attribute") def resolve_before_request(self) -> AsyncAnyCallable | None: """Resolve the before_handler handler by starting from the route handler and moving up. @@ -415,11 +473,9 @@ def resolve_before_request(self) -> AsyncAnyCallable | None: Returns: An optional :class:`before request lifecycle hook handler <.types.BeforeRequestHookHandler>` """ - if self._resolved_before_request is Empty: - before_request_handlers = [layer.before_request for layer in self.ownership_layers if layer.before_request] - self._resolved_before_request = before_request_handlers[-1] if before_request_handlers else None - return cast("AsyncAnyCallable | None", self._resolved_before_request) + return self.before_request + @litestar_deprecated("3.0", removal_in="4.0", alternative=".after_response attribute") def resolve_after_response(self) -> AsyncAnyCallable | None: """Resolve the after_response handler by starting from the route handler and moving up. @@ -429,16 +485,9 @@ def resolve_after_response(self) -> AsyncAnyCallable | None: Returns: An optional :class:`after response lifecycle hook handler <.types.AfterResponseHookHandler>` """ - if self._resolved_after_response is Empty: - after_response_handlers: list[AsyncAnyCallable] = [ - layer.after_response # type: ignore[misc] - for layer in self.ownership_layers - if layer.after_response - ] - self._resolved_after_response = after_response_handlers[-1] if after_response_handlers else None - - return cast("AsyncAnyCallable | None", self._resolved_after_response) + return self.after_response + @litestar_deprecated("3.0", removal_in="4.0", alternative=".include_in_schema property") def resolve_include_in_schema(self) -> bool: """Resolve the 'include_in_schema' property by starting from the route handler and moving up. @@ -448,15 +497,14 @@ def resolve_include_in_schema(self) -> bool: Returns: bool: The resolved 'include_in_schema' property. """ - if self._resolved_include_in_schema is Empty: - include_in_schemas = [ - i.include_in_schema for i in self.ownership_layers if isinstance(i.include_in_schema, bool) - ] - self._resolved_include_in_schema = include_in_schemas[-1] if include_in_schemas else True + return self.include_in_schema - return self._resolved_include_in_schema + @property + def include_in_schema(self) -> bool: + return self._include_in_schema if self._include_in_schema is not Empty else True - def resolve_security(self) -> list[SecurityRequirement]: + @litestar_deprecated("3.0", removal_in="4.0", alternative=".security attribute") + def resolve_security(self) -> tuple[SecurityRequirement, ...]: """Resolve the security property by starting from the route handler and moving up. Security requirements are additive, so the security requirements of the route handler are the sum of all @@ -465,139 +513,45 @@ def resolve_security(self) -> list[SecurityRequirement]: Returns: list[SecurityRequirement]: The resolved security property. """ - if self._resolved_security is Empty: - self._resolved_security = [] - for layer in self.ownership_layers: - if isinstance(layer.security, Sequence): - self._resolved_security.extend(layer.security) - - return self._resolved_security + return self.security - def resolve_tags(self) -> list[str]: + @litestar_deprecated("3.0", removal_in="4.0", alternative=".tags attribute") + def resolve_tags(self) -> frozenset[str]: """Resolve the tags property by starting from the route handler and moving up. Tags are additive, so the tags of the route handler are the sum of all tags of the ownership layers. - - Returns: - list[str]: A sorted list of unique tags. """ - if self._resolved_tags is Empty: - tag_set = set() - for layer in self.ownership_layers: - for tag in layer.tags or []: - tag_set.add(tag) - self._resolved_tags = sorted(tag_set) - - return self._resolved_tags + return self.tags + @litestar_deprecated("3.0", removal_in="4.0", alternative=".request_max_body_size property") def resolve_request_max_body_size(self) -> int | None: - if (resolved_limits := self._resolved_request_max_body_size) is not Empty: - return resolved_limits - - max_body_size = self._resolved_request_max_body_size = next( # pyright: ignore - ( - max_body_size - for layer in reversed(self.ownership_layers) - if (max_body_size := layer.request_max_body_size) is not Empty - ), - Empty, - ) - if max_body_size is Empty: - raise ImproperlyConfiguredException( - "'request_max_body_size' set to 'Empty' on all layers. To omit a limit, " - "set 'request_max_body_size=None'" - ) - return max_body_size - - def get_response_handler(self, is_response_type_data: bool = False) -> Callable[..., Awaitable[ASGIApp]]: - """Resolve the response_handler function for the route handler. - - This method is memoized so the computation occurs only once. - - Args: - is_response_type_data: Whether to return a handler for 'Response' instances. + return self.request_max_body_size - Returns: - Async Callable to handle an HTTP Request - """ - if self._response_handler_mapping["default_handler"] is Empty: - after_request_handlers: list[AsyncAnyCallable] = [ - layer.after_request # type: ignore[misc] - for layer in self.ownership_layers - if layer.after_request - ] - after_request = cast( - "AfterRequestHookHandler | None", - after_request_handlers[-1] if after_request_handlers else None, - ) - - media_type = self.media_type.value if isinstance(self.media_type, Enum) else self.media_type - response_class = self.resolve_response_class() - headers = self.resolve_response_headers() - cookies = self.resolve_response_cookies() - type_encoders = self.resolve_type_encoders() + @property + def request_max_body_size(self) -> int | None: + return value_or_default(self._request_max_body_size, None) # pyright: ignore - return_type = self.parsed_fn_signature.return_type - return_annotation = return_type.annotation + def on_registration(self, route: BaseRoute, app: Litestar) -> None: + super().on_registration(route=route, app=app) - self._response_handler_mapping["response_type_handler"] = response_type_handler = create_response_handler( - after_request=after_request, - background=self.background, - cookies=cookies, - headers=headers, - media_type=media_type, - status_code=self.status_code, - type_encoders=type_encoders, + if self._request_max_body_size is Empty: + raise ImproperlyConfiguredException( + "'request_max_body_size' set to 'Empty' on all layers. To omit a limit, " + "set 'request_max_body_size=None'" ) - if return_type.is_subclass_of(Response): - self._response_handler_mapping["default_handler"] = response_type_handler - elif is_async_callable(return_annotation) or return_annotation is ASGIApp: - self._response_handler_mapping["default_handler"] = create_generic_asgi_response_handler( - after_request=after_request - ) - else: - self._response_handler_mapping["default_handler"] = create_data_handler( - after_request=after_request, - background=self.background, - cookies=cookies, - headers=headers, - media_type=media_type, - response_class=response_class, - status_code=self.status_code, - type_encoders=type_encoders, - ) - - return cast( - "Callable[[Any], Awaitable[ASGIApp]]", - self._response_handler_mapping["response_type_handler"] - if is_response_type_data - else self._response_handler_mapping["default_handler"], - ) - - async def to_response(self, data: Any, request: Request) -> ASGIApp: - """Return a :class:`Response <.response.Response>` from the handler by resolving and calling it. - - Args: - data: Either an instance of a :class:`Response <.response.Response>`, - a Response instance or an arbitrary value. - request: A :class:`Request <.connection.Request>` instance - - Returns: - A Response instance - """ - if return_dto_type := self.resolve_return_dto(): - data = return_dto_type(request).data_to_encodable_type(data) - - response_handler = self.get_response_handler(is_response_type_data=isinstance(data, Response)) - return await response_handler(data=data, request=request) - - def on_registration(self, app: Litestar, route: BaseRoute) -> None: - super().on_registration(app, route=route) - self.resolve_after_response() - self.resolve_include_in_schema() - self._get_kwargs_model_for_route(route.path_parameters) + self._default_response_handler, self._response_type_handler = self._create_response_handlers( + media_type=self.media_type, + response_class=self.response_class, + cookies=self.response_cookies, + headers=self.response_headers, + type_encoders=self.type_encoders, + return_type=self.parsed_fn_signature.return_type, + status_code=self.status_code, + background=self.background, + after_request=self.after_request, + ) def _get_kwargs_model_for_route(self, path_parameters: Iterable[str]) -> KwargsModel: key = tuple(path_parameters) @@ -613,7 +567,7 @@ def _validate_handler_function(self) -> None: if return_type.annotation is Empty: raise ImproperlyConfiguredException( - f"A return value of a route handler function {self} should be type annotated. " + f"Missing return type annotation for route handler function {self!r}. " "If your function doesn't return a value, annotate it as returning 'None'." ) @@ -657,6 +611,48 @@ def _validate_handler_function(self) -> None: "processed request data, use the 'data' parameter." ) + @staticmethod + def _create_response_handlers( + *, + media_type: MediaType | str, + response_class: type[Response], + headers: frozenset[ResponseHeader], + cookies: frozenset[Cookie], + type_encoders: TypeEncodersMap, + return_type: FieldDefinition, + status_code: int, + background: BackgroundTask | BackgroundTasks | None, + after_request: AfterRequestHookHandler | None, + ) -> tuple[Callable[..., Awaitable[ASGIApp]], Callable[..., Awaitable[ASGIApp]]]: + media_type = media_type.value if isinstance(media_type, Enum) else media_type + return_annotation = return_type.annotation + + response_type_handler = create_response_handler( + after_request=after_request, + background=background, + cookies=cookies, + headers=headers, + media_type=media_type, + status_code=status_code, + type_encoders=type_encoders, + ) + + if is_async_callable(return_annotation) or return_annotation is ASGIApp: + default_handler = create_generic_asgi_response_handler(after_request=after_request) + else: + default_handler = create_data_handler( + after_request=after_request, + background=background, + cookies=cookies, + headers=headers, + media_type=media_type, + response_class=response_class, + status_code=status_code, + type_encoders=type_encoders, + ) + + return default_handler, response_type_handler + async def handle(self, connection: Request[Any, Any, Any]) -> None: """ASGI app that creates a :class:`~.connection.Request` from the passed in args, determines which handler function to call and then handles the call. @@ -670,7 +666,7 @@ async def handle(self, connection: Request[Any, Any, Any]) -> None: None """ - if self.resolve_guards(): + if self.guards: await self.authorize_connection(connection=connection) try: @@ -678,7 +674,7 @@ async def handle(self, connection: Request[Any, Any, Any]) -> None: await response(connection.scope, connection.receive, connection.send) - if after_response_handler := self.resolve_after_response(): + if after_response_handler := self.after_response: await after_response_handler(connection) finally: if (form_data := ScopeState.from_scope(connection.scope).form) is not Empty: @@ -700,22 +696,13 @@ async def _get_response_for_request( Returns: An instance of Response or a compatible ASGIApp or a subclass of it """ - if self.cache and (response := await self._get_cached_response(request=request)): - return response - - return await self._call_handler_function(request=request) + if self.cache and (cached_response := await self._get_cached_response(request=request)): + return cached_response - async def _call_handler_function(self, request: Request) -> ASGIApp: - """Call the before request handlers, retrieve any data required for the route handler, and call the route - handler's ``to_response`` method. - - This is wrapped in a try except block - and if an exception is raised, - it tries to pass it to an appropriate exception handler - if defined. - """ response_data: Any = None cleanup_group: DependencyCleanupGroup | None = None - if before_request_handler := self.resolve_before_request(): + if before_request_handler := self.before_request: response_data = await before_request_handler(request) if not response_data: @@ -732,22 +719,19 @@ async def _get_response_data(self, request: Request) -> tuple[Any, DependencyCle """Determine what kwargs are required for the given route handler's ``fn`` and calls it.""" parsed_kwargs: dict[str, Any] = {} cleanup_group: DependencyCleanupGroup | None = None - parameter_model = self._get_kwargs_model_for_route(request.scope["path_params"].keys()) + kwargs_models_model = self._get_kwargs_model_for_route(request.scope["path_params"].keys()) - if parameter_model.has_kwargs and self.signature_model: + if kwargs_models_model.has_kwargs and self.signature_model: try: - kwargs = await parameter_model.to_kwargs(connection=request) + kwargs = await kwargs_models_model.to_kwargs(connection=request) except SerializationException as e: raise ClientException(str(e)) from e - if "data" in kwargs: - data = kwargs["data"] - - if data is Empty: - del kwargs["data"] + if "data" in kwargs and kwargs["data"] is Empty: + del kwargs["data"] - if parameter_model.dependency_batches: - cleanup_group = await parameter_model.resolve_dependencies(request, kwargs) + if kwargs_models_model.dependency_batches: + cleanup_group = await kwargs_models_model.resolve_dependencies(request, kwargs) parsed_kwargs = self.signature_model.parse_values_from_connection_kwargs( connection=request, @@ -791,3 +775,24 @@ async def cached_response(scope: Scope, receive: Receive, send: Send) -> None: await send(message) return cached_response + + async def to_response(self, data: Any, request: Request) -> ASGIApp: + """Return a :class:`Response <.response.Response>` from the handler by resolving and calling it. + + Args: + data: Either an instance of a :class:`Response <.response.Response>`, + a Response instance or an arbitrary value. + request: A :class:`Request <.connection.Request>` instance + + Returns: + A Response instance + """ + if return_dto_type := self.return_dto: + data = return_dto_type(request).data_to_encodable_type(data) + + handler = cast( + Callable[..., Awaitable[ASGIApp]], + self._response_type_handler if isinstance(data, Response) else self._default_response_handler, + ) + + return await handler(data=data, request=request) diff --git a/litestar/handlers/http_handlers/decorators.py b/litestar/handlers/http_handlers/decorators.py index 748531e801..735a6665c7 100644 --- a/litestar/handlers/http_handlers/decorators.py +++ b/litestar/handlers/http_handlers/decorators.py @@ -54,7 +54,7 @@ ) from litestar.types.callable_types import AnyCallable, OperationIDCreator -__all__ = ("delete", "get", "head", "patch", "post", "put") +__all__ = ("delete", "get", "head", "patch", "post", "put", "route") def route( @@ -1030,7 +1030,6 @@ def put( hook for deserialization. type_encoders: A mapping of types to callables that transform them into types supported for serialization. handler_class: Route handler class instantiated by the decorator - **kwargs: Any additional kwarg - will be set in the opt dictionary. """ diff --git a/litestar/handlers/websocket_handlers/_utils.py b/litestar/handlers/websocket_handlers/_utils.py index bcd90ac1c2..fefef91b8c 100644 --- a/litestar/handlers/websocket_handlers/_utils.py +++ b/litestar/handlers/websocket_handlers/_utils.py @@ -20,7 +20,7 @@ def create_handle_receive(listener: WebsocketListenerRouteHandler) -> Callable[[WebSocket], Coroutine[Any, None, None]]: - if data_dto := listener.resolve_data_dto(): + if data_dto := listener.data_dto: async def handle_receive(socket: WebSocket) -> Any: received_data = await socket.receive_data(mode=listener._receive_mode) @@ -44,7 +44,7 @@ async def handle_receive(socket: WebSocket) -> Any: async def handle_receive(socket: WebSocket) -> Any: received_data = await socket.receive_data(mode=listener._receive_mode) - return decode_json(value=received_data, type_decoders=socket.route_handler.resolve_type_decoders()) + return decode_json(value=received_data, type_decoders=socket.route_handler.type_decoders) return handle_receive @@ -54,7 +54,7 @@ def create_handle_send( ) -> Callable[[WebSocket, Any], Coroutine[None, None, None]]: json_encoder = JsonEncoder(enc_hook=listener.default_serializer) - if return_dto := listener.resolve_return_dto(): + if return_dto := listener.return_dto: async def handle_send(socket: WebSocket, data: Any) -> None: encoded_data = return_dto(socket).data_to_encodable_type(data) @@ -78,7 +78,7 @@ async def handle_send(socket: WebSocket, data: Any) -> None: class ListenerHandler: - __slots__ = ("_can_send_data", "_fn", "_listener", "_pass_socket") + __slots__ = ("_can_send_data", "_fn", "_listener", "_pass_socket", "_wrapped_fn") def __init__( self, @@ -88,6 +88,7 @@ def __init__( namespace: dict[str, Any], ) -> None: self._can_send_data = not parsed_signature.return_type.is_subclass_of(NoneType) + self._wrapped_fn = fn self._fn = ensure_async_callable(fn) self._listener = listener self._pass_socket = "socket" in parsed_signature.parameters diff --git a/litestar/handlers/websocket_handlers/listener.py b/litestar/handlers/websocket_handlers/listener.py index 15ac6b3540..b21f8d3151 100644 --- a/litestar/handlers/websocket_handlers/listener.py +++ b/litestar/handlers/websocket_handlers/listener.py @@ -10,11 +10,10 @@ Dict, Mapping, Optional, - cast, + Sequence, overload, ) -from litestar._signature import SignatureModel from litestar.connection import WebSocket from litestar.exceptions import ImproperlyConfiguredException, WebSocketDisconnect from litestar.types import ( @@ -42,10 +41,11 @@ if TYPE_CHECKING: from typing import Coroutine - from litestar import Router + from litestar import Litestar, Router from litestar.dto import AbstractDTO + from litestar.routes import BaseRoute from litestar.types.asgi_types import WebSocketMode - from litestar.types.composite_types import TypeDecodersSequence + from litestar.types.composite_types import ParametersMap, TypeDecodersSequence __all__ = ("WebsocketListener", "WebsocketListenerRouteHandler", "websocket_listener") @@ -77,8 +77,8 @@ def __init__( dependencies: Dependencies | None = None, dto: type[AbstractDTO] | None | EmptyType = Empty, exception_handlers: dict[int | type[Exception], ExceptionHandler] | None = None, - guards: list[Guard] | None = None, - middleware: list[Middleware] | None = None, + guards: Sequence[Guard] | None = None, + middleware: Sequence[Middleware] | None = None, receive_mode: WebSocketMode = "text", send_mode: WebSocketMode = "text", name: str | None = None, @@ -88,6 +88,7 @@ def __init__( type_decoders: TypeDecodersSequence | None = None, type_encoders: TypeEncodersMap | None = None, websocket_class: type[WebSocket] | None = None, + parameters: ParametersMap | None = None, **kwargs: Any, ) -> None: ... @@ -101,8 +102,8 @@ def __init__( dependencies: Dependencies | None = None, dto: type[AbstractDTO] | None | EmptyType = Empty, exception_handlers: dict[int | type[Exception], ExceptionHandler] | None = None, - guards: list[Guard] | None = None, - middleware: list[Middleware] | None = None, + guards: Sequence[Guard] | None = None, + middleware: Sequence[Middleware] | None = None, receive_mode: WebSocketMode = "text", send_mode: WebSocketMode = "text", name: str | None = None, @@ -114,6 +115,7 @@ def __init__( type_decoders: TypeDecodersSequence | None = None, type_encoders: TypeEncodersMap | None = None, websocket_class: type[WebSocket] | None = None, + parameters: ParametersMap | None = None, **kwargs: Any, ) -> None: ... @@ -127,8 +129,8 @@ def __init__( dependencies: Dependencies | None = None, dto: type[AbstractDTO] | None | EmptyType = Empty, exception_handlers: dict[int | type[Exception], ExceptionHandler] | None = None, - guards: list[Guard] | None = None, - middleware: list[Middleware] | None = None, + guards: Sequence[Guard] | None = None, + middleware: Sequence[Middleware] | None = None, receive_mode: WebSocketMode = "text", send_mode: WebSocketMode = "text", name: str | None = None, @@ -140,6 +142,7 @@ def __init__( type_decoders: TypeDecodersSequence | None = None, type_encoders: TypeEncodersMap | None = None, websocket_class: type[WebSocket] | None = None, + parameters: ParametersMap | None = None, **kwargs: Any, ) -> None: """Initialize ``WebsocketRouteHandler`` @@ -176,9 +179,10 @@ def __init__( type_decoders: A sequence of tuples, each composed of a predicate testing for type identity and a msgspec hook for deserialization. type_encoders: A mapping of types to callables that transform them into types supported for serialization. - **kwargs: Any additional kwarg - will be set in the opt dictionary. websocket_class: A custom subclass of :class:`WebSocket <.connection.WebSocket>` to be used as route handler's default websocket class. + parameters: A mapping of :func:`Parameter <.params.Parameter>` definitions + **kwargs: Any additional kwarg - will be set in the opt dictionary. """ if connection_lifespan and any([on_accept, on_disconnect, connection_accept_handler is not WebSocket.accept]): raise ImproperlyConfiguredException( @@ -195,9 +199,6 @@ def __init__( self.connection_accept_handler = connection_accept_handler self.on_accept = ensure_async_callable(on_accept) if on_accept else None self.on_disconnect = ensure_async_callable(on_disconnect) if on_disconnect else None - self.type_decoders = type_decoders - self.type_encoders = type_encoders - self.websocket_class = websocket_class listener_dependencies = dict(dependencies or {}) @@ -226,11 +227,28 @@ def __init__( type_decoders=type_decoders, type_encoders=type_encoders, websocket_class=websocket_class, + parameters=parameters, **kwargs, ) - def _prepare_fn(self, fn: AnyCallable) -> ListenerHandler: - parsed_signature = ParsedSignature.from_fn(fn, self.resolve_signature_namespace()) + def _get_merge_opts(self, others: tuple[Router, ...]) -> dict[str, Any]: + merge_opts = super()._get_merge_opts(others) + merge_opts.update( + receive_mode=self._receive_mode, + send_mode=self._send_mode, + connection_lifespan=self._connection_lifespan, + connection_accept_handler=self.connection_accept_handler, + on_accept=self.on_accept, + on_disconnect=self.on_disconnect, + ) + return merge_opts + + def on_registration(self, route: BaseRoute, app: Litestar) -> None: + self.fn = self._prepare_fn() + super().on_registration(route, app) + + def _prepare_fn(self) -> ListenerHandler: + parsed_signature = ParsedSignature.from_fn(self.fn, self.signature_namespace) if "data" not in parsed_signature.parameters: raise ImproperlyConfiguredException("Websocket listeners must accept a 'data' parameter") @@ -246,36 +264,19 @@ def _prepare_fn(self, fn: AnyCallable) -> ListenerHandler: self._parsed_fn_signature = ParsedSignature.from_signature( create_handler_signature(parsed_signature.original_signature), fn_type_hints={ - **get_fn_type_hints(fn, namespace=self.resolve_signature_namespace()), - **get_fn_type_hints(ListenerHandler.__call__, namespace=self.resolve_signature_namespace()), + **get_fn_type_hints(self.fn, namespace=self.signature_namespace), + **get_fn_type_hints(ListenerHandler.__call__, namespace=self.signature_namespace), }, ) return ListenerHandler( - listener=self, fn=fn, parsed_signature=parsed_signature, namespace=self.resolve_signature_namespace() + listener=self, fn=self.fn, parsed_signature=parsed_signature, namespace=self.signature_namespace ) def _validate_handler_function(self) -> None: """Validate the route handler function once it's set by inspecting its return annotations.""" # validation occurs in the call method - @property - def signature_model(self) -> type[SignatureModel]: - """Get the signature model for the route handler. - - Returns: - A signature model for the route handler. - - """ - if self._signature_model is Empty: - self._signature_model = SignatureModel.create( - dependency_name_set=self.dependency_name_set, - fn=cast("AnyCallable", self.fn), - parsed_signature=self.parsed_fn_signature, - type_decoders=self.resolve_type_decoders(), - ) - return self._signature_model - @asynccontextmanager async def default_connection_lifespan( self, @@ -365,19 +366,11 @@ class WebsocketListener(ABC): default websocket class. """ - def __init__(self, owner: Router) -> None: - """Initialize a WebsocketListener instance. - - Args: - owner: The :class:`Router <.router.Router>` instance that owns this listener. - """ - self._owner = owner - def to_handler(self) -> WebsocketListenerRouteHandler: on_accept = self.on_accept if self.on_accept != WebsocketListener.on_accept else None on_disconnect = self.on_disconnect if self.on_disconnect != WebsocketListener.on_disconnect else None - handler = WebsocketListenerRouteHandler( + return WebsocketListenerRouteHandler( dependencies=self.dependencies, dto=self.dto, exception_handlers=self.exception_handlers, @@ -397,8 +390,6 @@ def to_handler(self) -> WebsocketListenerRouteHandler: websocket_class=self.websocket_class, fn=self.on_receive, ) - handler.owner = self._owner - return handler def on_accept(self, *args: Any, **kwargs: Any) -> Any: """Called after a :class:`WebSocket <.connection.WebSocket>` connection diff --git a/litestar/handlers/websocket_handlers/route_handler.py b/litestar/handlers/websocket_handlers/route_handler.py index 5c108db429..ce4a81d3dc 100644 --- a/litestar/handlers/websocket_handlers/route_handler.py +++ b/litestar/handlers/websocket_handlers/route_handler.py @@ -1,24 +1,25 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Callable, Mapping +from typing import TYPE_CHECKING, Any, Callable, Mapping, Sequence from litestar.connection import WebSocket from litestar.exceptions import ImproperlyConfiguredException from litestar.handlers import BaseRouteHandler -from litestar.types import AsyncAnyCallable, Empty +from litestar.types import AsyncAnyCallable, Empty, ParametersMap from litestar.types.builtin_types import NoneType +from litestar.utils import deprecated from litestar.utils.predicates import is_async_callable if TYPE_CHECKING: + from litestar import Litestar, Router from litestar._kwargs import KwargsModel from litestar._kwargs.cleanup import DependencyCleanupGroup - from litestar.app import Litestar from litestar.routes import BaseRoute from litestar.types import Dependencies, EmptyType, ExceptionHandler, Guard, Middleware class WebsocketRouteHandler(BaseRouteHandler): - __slots__ = ("_kwargs_model", "websocket_class") + __slots__ = ("_kwargs_model", "_websocket_class") def __init__( self, @@ -27,12 +28,13 @@ def __init__( fn: AsyncAnyCallable, dependencies: Dependencies | None = None, exception_handlers: dict[int | type[Exception], ExceptionHandler] | None = None, - guards: list[Guard] | None = None, - middleware: list[Middleware] | None = None, + guards: Sequence[Guard] | None = None, + middleware: Sequence[Middleware] | None = None, name: str | None = None, opt: dict[str, Any] | None = None, signature_namespace: Mapping[str, Any] | None = None, websocket_class: type[WebSocket] | None = None, + parameters: ParametersMap | None = None, **kwargs: Any, ) -> None: """Route handler for WebSocket routes. @@ -55,9 +57,10 @@ def __init__( type_encoders: A mapping of types to callables that transform them into types supported for serialization. websocket_class: A custom subclass of :class:`WebSocket <.connection.WebSocket>` to be used as route handler's default websocket class. + parameters: A mapping of :func:`Parameter <.params.Parameter>` definitions **kwargs: Any additional kwarg - will be set in the opt dictionary. """ - self.websocket_class = websocket_class + self._websocket_class = websocket_class self._kwargs_model: KwargsModel | EmptyType = Empty super().__init__( @@ -70,9 +73,19 @@ def __init__( name=name, opt=opt, signature_namespace=signature_namespace, + parameters=parameters, **kwargs, ) + def _get_merge_opts(self, others: tuple[Router, ...]) -> dict[str, Any]: + merge_opts = super()._get_merge_opts(others) + merge_opts["websocket_class"] = self._websocket_class or next( + (o.websocket_class for o in others if o.websocket_class), None + ) + + return merge_opts + + @deprecated("3.0", removal_in="4.0", alternative=".websocket_class property") def resolve_websocket_class(self) -> type[WebSocket]: """Return the closest custom WebSocket class in the owner graph or the default Websocket class. @@ -81,10 +94,11 @@ def resolve_websocket_class(self) -> type[WebSocket]: Returns: The default :class:`WebSocket <.connection.WebSocket>` class for the route handler. """ - return next( - (layer.websocket_class for layer in reversed(self.ownership_layers) if layer.websocket_class is not None), - WebSocket, - ) + return self.websocket_class + + @property + def websocket_class(self) -> type[WebSocket]: + return self._websocket_class or WebSocket def _validate_handler_function(self) -> None: """Validate the route handler function once it's set by inspecting its return annotations.""" @@ -105,8 +119,8 @@ def _validate_handler_function(self) -> None: if not is_async_callable(self.fn): raise ImproperlyConfiguredException(f"{self}: WebSocket handler functions must be asynchronous") - def on_registration(self, app: Litestar, route: BaseRoute) -> None: - super().on_registration(app=app, route=route) + def on_registration(self, route: BaseRoute, app: Litestar) -> None: + super().on_registration(route=route, app=app) self._kwargs_model = self._create_kwargs_model(path_parameters=route.path_parameters) async def handle(self, connection: WebSocket[Any, Any, Any]) -> None: @@ -119,21 +133,21 @@ async def handle(self, connection: WebSocket[Any, Any, Any]) -> None: None """ - handler_parameter_model = self._kwargs_model - if handler_parameter_model is Empty: + handler_kwargs_model = self._kwargs_model + if handler_kwargs_model is Empty: raise ImproperlyConfiguredException("handler parameter model not defined") - if self.resolve_guards(): + if self.guards: await self.authorize_connection(connection=connection) parsed_kwargs: dict[str, Any] = {} cleanup_group: DependencyCleanupGroup | None = None - if handler_parameter_model.has_kwargs and self.signature_model: - parsed_kwargs = await handler_parameter_model.to_kwargs(connection=connection) + if handler_kwargs_model.has_kwargs: + parsed_kwargs = await handler_kwargs_model.to_kwargs(connection=connection) - if handler_parameter_model.dependency_batches: - cleanup_group = await handler_parameter_model.resolve_dependencies(connection, parsed_kwargs) + if handler_kwargs_model.dependency_batches: + cleanup_group = await handler_kwargs_model.resolve_dependencies(connection, parsed_kwargs) parsed_kwargs = self.signature_model.parse_values_from_connection_kwargs( connection=connection, kwargs=parsed_kwargs diff --git a/litestar/handlers/websocket_handlers/stream.py b/litestar/handlers/websocket_handlers/stream.py index ec3b603314..1defa0cfbe 100644 --- a/litestar/handlers/websocket_handlers/stream.py +++ b/litestar/handlers/websocket_handlers/stream.py @@ -211,11 +211,11 @@ class WebSocketStreamHandler(WebsocketRouteHandler): __slots__ = ("_ws_stream_options",) _ws_stream_options: _WebSocketStreamOptions - def on_registration(self, app: Litestar, route: BaseRoute) -> None: + def on_registration(self, route: BaseRoute, app: Litestar) -> None: self._ws_stream_options = self.opt["stream_options"] parsed_handler_signature = parsed_stream_fn_signature = ParsedSignature.from_fn( - self.fn, self.resolve_signature_namespace() + self.fn, self.signature_namespace ) if not parsed_stream_fn_signature.return_type.is_subclass_of(AsyncGenerator): @@ -254,7 +254,8 @@ def on_registration(self, app: Litestar, route: BaseRoute) -> None: self._parsed_return_field = parsed_stream_fn_signature.return_type.inner_types[0] json_encoder = JsonEncoder(enc_hook=self.default_serializer) - return_dto = self.resolve_return_dto() + self._dto = self._resolve_data_dto(app=app) + self._return_dto = return_dto = self._resolve_return_dto(app=app, data_dto=self._dto) # make sure the closure doesn't capture self._ws_stream / self send_mode: WebSocketMode = self._ws_stream_options.send_mode # pyright: ignore @@ -292,7 +293,7 @@ async def handler_fn(*args: Any, socket: WebSocket, **kw: Any) -> None: self.fn = handler_fn # pyright: ignore - super().on_registration(app, route) + super().on_registration(route, app) class _WebSocketStreamOptions: diff --git a/litestar/middleware/_internal/exceptions/middleware.py b/litestar/middleware/_internal/exceptions/middleware.py index 039666b294..1ffdcc4e78 100644 --- a/litestar/middleware/_internal/exceptions/middleware.py +++ b/litestar/middleware/_internal/exceptions/middleware.py @@ -167,7 +167,7 @@ async def handle_request_exception( request: Request[Any, Any, Any] = litestar_app.request_class(scope=scope, receive=receive, send=send) response = exception_handler(request, exc) route_handler: BaseRouteHandler | None = scope.get("route_handler") - type_encoders = route_handler.resolve_type_encoders() if route_handler else litestar_app.type_encoders + type_encoders = route_handler.type_encoders if route_handler else litestar_app.type_encoders await response.to_asgi_response(request=request, type_encoders=type_encoders)( scope=scope, receive=receive, send=send ) diff --git a/litestar/openapi/plugins.py b/litestar/openapi/plugins.py index a006381737..8871934c00 100644 --- a/litestar/openapi/plugins.py +++ b/litestar/openapi/plugins.py @@ -73,7 +73,7 @@ def render_json(request: Request, openapi_schema: dict[str, Any]) -> bytes: Returns: The rendered JSON. """ - return encode_json(openapi_schema, serializer=get_serializer(request.route_handler.resolve_type_encoders())) + return encode_json(openapi_schema, serializer=get_serializer(request.route_handler.type_encoders)) @abstractmethod def render(self, request: Request, openapi_schema: dict[str, Any]) -> bytes: @@ -186,7 +186,7 @@ def render(self, request: Request, openapi_schema: dict[str, Any]) -> bytes: # UNSET value (possible if the examples are being generated for a partial DTO model which makes # every type a union with UNSET) are stripped out. openapi_schema = msgspec.to_builtins( - openapi_schema, enc_hook=get_serializer(request.route_handler.resolve_type_encoders()) + openapi_schema, enc_hook=get_serializer(request.route_handler.type_encoders) ) return yaml.dump(openapi_schema, default_flow_style=False).encode("utf-8") diff --git a/litestar/router.py b/litestar/router.py index b5cf7fe1dd..d4ef005426 100644 --- a/litestar/router.py +++ b/litestar/router.py @@ -1,19 +1,11 @@ from __future__ import annotations -from collections import defaultdict -from copy import copy, deepcopy -from typing import TYPE_CHECKING, Any, Mapping, Sequence, cast +from typing import TYPE_CHECKING, Any, Mapping, Sequence from litestar._layers.utils import narrow_response_cookies, narrow_response_headers -from litestar.controller import Controller from litestar.exceptions import ImproperlyConfiguredException -from litestar.handlers.asgi_handlers import ASGIRouteHandler -from litestar.handlers.http_handlers import HTTPRouteHandler -from litestar.handlers.http_handlers._options import create_options_handler -from litestar.handlers.websocket_handlers import WebsocketListener, WebsocketRouteHandler -from litestar.routes import ASGIRoute, HTTPRoute, WebSocketRoute from litestar.types.empty import Empty -from litestar.utils import find_index, is_class_and_subclass, join_paths, normalize_path, unique +from litestar.utils import normalize_path from litestar.utils.signature import add_types_to_signature_namespace from litestar.utils.sync import ensure_async_callable @@ -26,7 +18,6 @@ from litestar.dto import AbstractDTO from litestar.openapi.spec import SecurityRequirement from litestar.response import Response - from litestar.routes import BaseRoute from litestar.types import ( AfterRequestHookHandler, AfterResponseHookHandler, @@ -37,10 +28,9 @@ Middleware, ParametersMap, ResponseCookies, - RouteHandlerMapItem, - RouteHandlerType, TypeEncodersMap, ) + from litestar.types.callable_types import AsyncAfterRequestHookHandler, AsyncAfterResponseHookHandler from litestar.types.composite_types import Dependencies, ResponseHeaders, TypeDecodersSequence from litestar.types.empty import EmptyType @@ -64,7 +54,6 @@ class Router: "include_in_schema", "middleware", "opt", - "owner", "parameters", "path", "registered_route_handler_ids", @@ -74,7 +63,7 @@ class Router: "response_cookies", "response_headers", "return_dto", - "routes", + "route_handlers", "security", "signature_namespace", "tags", @@ -172,165 +161,44 @@ def __init__( all route handlers, controllers and other routers associated with the router instance. """ - self.after_request = ensure_async_callable(after_request) if after_request else None # pyright: ignore - self.after_response = ensure_async_callable(after_response) if after_response else None + self.after_request: AsyncAfterRequestHookHandler | None = ( + ensure_async_callable(after_request) if after_request else None # type: ignore[assignment] + ) + self.after_response: AsyncAfterResponseHookHandler | None = ( + ensure_async_callable(after_response) if after_response else None + ) self.before_request = ensure_async_callable(before_request) if before_request else None self.cache_control = cache_control self.dto = dto self.etag = etag self.dependencies = dict(dependencies or {}) self.exception_handlers = dict(exception_handlers or {}) - self.guards = list(guards or []) + self.guards = tuple(guards or []) self.include_in_schema = include_in_schema self.middleware = list(middleware or []) self.opt = dict(opt or {}) - self.owner: Router | None = None self.parameters = dict(parameters or {}) self.path = normalize_path(path) self.request_class = request_class self.response_class = response_class - self.response_cookies = narrow_response_cookies(response_cookies) - self.response_headers = narrow_response_headers(response_headers) + self.response_cookies = narrow_response_cookies(response_cookies) if response_cookies else () + self.response_headers = narrow_response_headers(response_headers) if response_headers else () self.return_dto = return_dto - self.routes: list[HTTPRoute | ASGIRoute | WebSocketRoute] = [] self.security = list(security or []) self.signature_namespace = add_types_to_signature_namespace( signature_types or [], dict(signature_namespace or {}) ) self.tags = list(tags or []) self.registered_route_handler_ids: set[int] = set() - self.type_encoders = dict(type_encoders) if type_encoders is not None else None - self.type_decoders = list(type_decoders) if type_decoders is not None else None + self.type_encoders = dict(type_encoders) if type_encoders is not None else {} + self.type_decoders = list(type_decoders) if type_decoders is not None else () self.websocket_class = websocket_class self.request_max_body_size = request_max_body_size - for route_handler in route_handlers or []: - self.register(value=route_handler) - - def register(self, value: ControllerRouterHandler) -> list[BaseRoute]: - """Register a Controller, Route instance or RouteHandler on the router. - - Args: - value: a subclass or instance of Controller, an instance of :class:`Router` or a function/method that has - been decorated by any of the routing decorators, e.g. :class:`get <.handlers.get>`, - :class:`post <.handlers.post>`. - - Returns: - Collection of handlers added to the router. - """ - validated_value = self._validate_registration_value(value) - - routes: list[BaseRoute] = [] - - for route_path, handlers_map in self.get_route_handler_map(value=validated_value).items(): - path = join_paths([self.path, route_path]) - if http_handlers := unique( - [handler for handler in handlers_map.values() if isinstance(handler, HTTPRouteHandler)] - ): - if existing_handlers := unique( - [ - handler - for handler in self.route_handler_method_map.get(path, {}).values() - if isinstance(handler, HTTPRouteHandler) - ] - ): - http_handlers.extend(existing_handlers) - existing_route_index = find_index(self.routes, lambda x: x.path == path) # noqa: B023 - - if existing_route_index == -1: # pragma: no cover - raise ImproperlyConfiguredException("unable to find_index existing route index") - - route: WebSocketRoute | ASGIRoute | HTTPRoute = HTTPRoute( - path=path, - route_handlers=_maybe_add_options_handler(path, http_handlers), - ) - self.routes[existing_route_index] = route - else: - route = HTTPRoute(path=path, route_handlers=_maybe_add_options_handler(path, http_handlers)) - self.routes.append(route) - - routes.append(route) - - if websocket_handler := handlers_map.get("websocket"): - route = WebSocketRoute(path=path, route_handler=cast("WebsocketRouteHandler", websocket_handler)) - self.routes.append(route) - routes.append(route) - - if asgi_handler := handlers_map.get("asgi"): - route = ASGIRoute(path=path, route_handler=cast("ASGIRouteHandler", asgi_handler)) - self.routes.append(route) - routes.append(route) - - return routes - - @property - def route_handler_method_map(self) -> dict[str, RouteHandlerMapItem]: - """Map route paths to :class:`~litestar.types.internal_types.RouteHandlerMapItem` - - Returns: - A dictionary mapping paths to route handlers - """ - route_map: defaultdict[str, RouteHandlerMapItem] = defaultdict(dict) - for route in self.routes: - if isinstance(route, HTTPRoute): - route_map[route.path] = route.route_handler_map # type: ignore[assignment] - else: - route_map[route.path]["websocket" if isinstance(route, WebSocketRoute) else "asgi"] = ( - route.route_handler - ) - - return route_map - - @classmethod - def get_route_handler_map( - cls, - value: RouteHandlerType | Router, - ) -> dict[str, RouteHandlerMapItem]: - """Map route handlers to HTTP methods.""" - if isinstance(value, Router): - return value.route_handler_method_map - - copied_value = copy(value) - if isinstance(value, HTTPRouteHandler): - return {path: {http_method: copied_value for http_method in value.http_methods} for path in value.paths} - - return { - path: {"websocket" if isinstance(value, WebsocketRouteHandler) else "asgi": copied_value} - for path in value.paths - } - - def _validate_registration_value(self, value: ControllerRouterHandler) -> RouteHandlerType | Router: - """Ensure values passed to the register method are supported.""" - if is_class_and_subclass(value, Controller): - return value(owner=self).as_router() - - # this narrows down to an ABC, but we assume a non-abstract subclass of the ABC superclass - if is_class_and_subclass(value, WebsocketListener): - return value(owner=self).to_handler() # pyright: ignore - - if isinstance(value, Router): - if value is self: - raise ImproperlyConfiguredException("Cannot register a router on itself") - - router_copy = deepcopy(value) - router_copy.owner = self - return router_copy - - if isinstance(value, (ASGIRouteHandler, HTTPRouteHandler, WebsocketRouteHandler)): - value.owner = self - return value - - raise ImproperlyConfiguredException( - "Unsupported value passed to `Router.register`. " - "If you passed in a function or method, " - "make sure to decorate it first with one of the routing decorators" - ) - + self.route_handlers = tuple(route_handlers) -def _maybe_add_options_handler(path: str, http_handlers: list[HTTPRouteHandler]) -> list[HTTPRouteHandler]: - handler_methods = {method for handler in http_handlers for method in handler.http_methods} - if "OPTIONS" not in handler_methods: - options_handler = create_options_handler(path=path, allow_methods={*handler_methods, "OPTIONS"}) # pyright: ignore - options_handler.owner = http_handlers[0].owner - return [*http_handlers, options_handler] - return http_handlers + def register(self, value: ControllerRouterHandler) -> None: + """Register a Controller, Route instance or RouteHandler on the router""" + if value is self: + raise ImproperlyConfiguredException("Cannot register a router on itself") + self.route_handlers = (*self.route_handlers, value) diff --git a/litestar/routes/asgi.py b/litestar/routes/asgi.py index 7686bacdd4..c45c67800b 100644 --- a/litestar/routes/asgi.py +++ b/litestar/routes/asgi.py @@ -49,7 +49,7 @@ async def handle(self, scope: Scope, receive: Receive, send: Send) -> None: copy_scope = self.route_handler.copy_scope connection = ASGIConnection["ASGIRouteHandler", Any, Any, Any]( - scope=handler_scope if copy_scope else scope, + scope=handler_scope if copy_scope is True else scope, receive=receive, send=send, ) diff --git a/litestar/routes/http.py b/litestar/routes/http.py index f2c8977c8e..0de4a1ff03 100644 --- a/litestar/routes/http.py +++ b/litestar/routes/http.py @@ -8,7 +8,8 @@ if TYPE_CHECKING: from litestar.handlers.http_handlers import HTTPRouteHandler - from litestar.types import Method, Receive, Send + from litestar.types import Receive, Send + from litestar.types.asgi_types import HttpMethodName class HTTPRoute(BaseRoute[HTTPScope]): @@ -32,7 +33,7 @@ def __init__( route_handlers: A list of :class:`~.handlers.HTTPRouteHandler`. """ super().__init__(path=path) - self.route_handler_map: dict[Method, HTTPRouteHandler] = self.create_handler_map(route_handlers) + self.route_handler_map: dict[HttpMethodName, HTTPRouteHandler] = self.create_handler_map(route_handlers) self.route_handlers = tuple(self.route_handler_map.values()) self.methods = tuple(self.route_handler_map) @@ -49,10 +50,10 @@ async def handle(self, scope: HTTPScope, receive: Receive, send: Send) -> None: None """ route_handler = self.route_handler_map[scope["method"]] - connection = route_handler.resolve_request_class()(scope=scope, receive=receive, send=send) + connection = route_handler.request_class(scope=scope, receive=receive, send=send) await route_handler.handle(connection=connection) - def create_handler_map(self, route_handlers: Iterable[HTTPRouteHandler]) -> dict[Method, HTTPRouteHandler]: + def create_handler_map(self, route_handlers: Iterable[HTTPRouteHandler]) -> dict[HttpMethodName, HTTPRouteHandler]: """Parse the ``router_handlers`` of this route and return a mapping of http- methods and route handlers. """ diff --git a/litestar/routes/websocket.py b/litestar/routes/websocket.py index 477c1da208..5bca9a4674 100644 --- a/litestar/routes/websocket.py +++ b/litestar/routes/websocket.py @@ -42,5 +42,5 @@ async def handle(self, scope: WebSocketScope, receive: Receive, send: Send) -> N Returns: None """ - socket = self.route_handler.resolve_websocket_class()(scope=scope, receive=receive, send=send) + socket = self.route_handler.websocket_class(scope=scope, receive=receive, send=send) await self.route_handler.handle(connection=socket) diff --git a/litestar/static_files/base.py b/litestar/static_files/base.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/litestar/static_files/config.py b/litestar/static_files/config.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/litestar/testing/request_factory.py b/litestar/testing/request_factory.py index b740c3b380..86be2470cb 100644 --- a/litestar/testing/request_factory.py +++ b/litestar/testing/request_factory.py @@ -41,9 +41,7 @@ def _create_default_route_handler( def _default_route_handler() -> None: ... - handler = handler_decorator("/", sync_to_thread=False, **(handler_kwargs or {}))(_default_route_handler) - handler.owner = app - return handler + return handler_decorator("/", sync_to_thread=False, **(handler_kwargs or {}))(_default_route_handler).merge(app) def _create_default_app() -> Litestar: @@ -304,7 +302,7 @@ def _create_request_with_data( body += chunk scope_state = ScopeState.from_scope(scope) scope_state.body = body - scope_state.exception_handlers = scope["route_handler"].resolve_exception_handlers() + scope_state.exception_handlers = scope["route_handler"].exception_handlers self._create_cookie_header(headers, cookies) scope["headers"] = self._build_headers(headers) return Request(scope=scope) diff --git a/litestar/types/asgi_types.py b/litestar/types/asgi_types.py index c7eb56b6b1..b5a63db5d7 100644 --- a/litestar/types/asgi_types.py +++ b/litestar/types/asgi_types.py @@ -104,7 +104,8 @@ from .internal_types import RouteHandlerType from .serialization import DataContainerType -Method: TypeAlias = Union[Literal["GET", "POST", "DELETE", "PATCH", "PUT", "HEAD", "TRACE", "OPTIONS"], HttpMethod] +HttpMethodName: TypeAlias = Literal["GET", "POST", "DELETE", "PATCH", "PUT", "HEAD", "TRACE", "OPTIONS"] +Method: TypeAlias = Union[HttpMethodName, HttpMethod] ScopeSession: TypeAlias = "EmptyType | Dict[str, Any] | DataContainerType | None" @@ -148,7 +149,7 @@ class BaseScope(HeaderScope): class HTTPScope(BaseScope): """HTTP-ASGI-scope.""" - method: Method + method: HttpMethodName type: Literal[ScopeType.HTTP] diff --git a/litestar/types/callable_types.py b/litestar/types/callable_types.py index 0f07295cc4..2c723c99fc 100644 --- a/litestar/types/callable_types.py +++ b/litestar/types/callable_types.py @@ -20,20 +20,31 @@ ExceptionT = TypeVar("ExceptionT", bound=Exception) AfterExceptionHookHandler: TypeAlias = "Callable[[ExceptionT, Scope], SyncOrAsyncUnion[None]]" -AfterRequestHookHandler: TypeAlias = ( - "Callable[[ASGIApp], SyncOrAsyncUnion[ASGIApp]] | Callable[[Response], SyncOrAsyncUnion[Response]]" +AsyncAfterRequestHookHandler: TypeAlias = ( + "Callable[[ASGIApp], Awaitable[ASGIApp]] | Callable[[Response], Awaitable[Response]]" ) -AfterResponseHookHandler: TypeAlias = "Callable[[Request], SyncOrAsyncUnion[None]]" +SyncAfterRequestHookHandler: TypeAlias = "Callable[[ASGIApp], ASGIApp] | Callable[[Response], Response]" +AfterRequestHookHandler: TypeAlias = "AsyncAfterRequestHookHandler | SyncAfterRequestHookHandler" + +AsyncAfterResponseHookHandler: TypeAlias = "Callable[[Request], Awaitable[None]]" +SyncAfterResponseHookHandler: TypeAlias = "Callable[[Request], None]" +AfterResponseHookHandler: TypeAlias = "AsyncAfterResponseHookHandler | SyncAfterResponseHookHandler" + +AsyncBeforeRequestHookHandler: TypeAlias = "Callable[[Request], Awaitable[Any]]" +BeforeRequestHookHandler: TypeAlias = "Callable[[Request], Any | Awaitable[Any]]" + + AsyncAnyCallable: TypeAlias = Callable[..., Awaitable[Any]] AnyCallable: TypeAlias = Callable[..., Any] AnyGenerator: TypeAlias = "Generator[Any, Any, Any] | AsyncGenerator[Any, Any]" BeforeMessageSendHookHandler: TypeAlias = "Callable[[Message, Scope], SyncOrAsyncUnion[None]]" -BeforeRequestHookHandler: TypeAlias = "Callable[[Request], Any | Awaitable[Any]]" CacheKeyBuilder: TypeAlias = "Callable[[Request], str]" ExceptionHandler: TypeAlias = "Callable[[Request, ExceptionT], Response]" ExceptionLoggingHandler: TypeAlias = "Callable[[Logger, Scope, list[str]], None]" GetLogger: TypeAlias = "Callable[..., Logger]" -Guard: TypeAlias = "Callable[[ASGIConnection, BaseRouteHandler], SyncOrAsyncUnion[None]]" +AsyncGuard: TypeAlias = "Callable[[ASGIConnection, BaseRouteHandler], Awaitable[None]]" +SyncGuard: TypeAlias = "Callable[[ASGIConnection, BaseRouteHandler], None]" +Guard: TypeAlias = "AsyncGuard | SyncGuard" LifespanHook: TypeAlias = "Callable[[Litestar], SyncOrAsyncUnion[Any]] | Callable[[], SyncOrAsyncUnion[Any]]" OnAppInitHandler: TypeAlias = "Callable[[AppConfig], AppConfig]" OperationIDCreator: TypeAlias = "Callable[[HTTPRouteHandler, Method, list[str | PathParameterDefinition]], str]" diff --git a/litestar/types/internal_types.py b/litestar/types/internal_types.py index d473c22677..93539771f0 100644 --- a/litestar/types/internal_types.py +++ b/litestar/types/internal_types.py @@ -18,6 +18,7 @@ from litestar.app import Litestar from litestar.controller import Controller + from litestar.handlers import BaseRouteHandler from litestar.handlers.asgi_handlers import ASGIRouteHandler from litestar.handlers.http_handlers import HTTPRouteHandler from litestar.handlers.websocket_handlers import WebsocketRouteHandler @@ -26,10 +27,11 @@ from litestar.template.config import EngineType from litestar.types import Method + ReservedKwargs: TypeAlias = Literal["request", "socket", "headers", "query", "cookies", "state", "data"] RouteHandlerType: TypeAlias = "HTTPRouteHandler | WebsocketRouteHandler | ASGIRouteHandler" ControllerRouterHandler: TypeAlias = "type[Controller] | RouteHandlerType | Router | Callable[..., Any]" -RouteHandlerMapItem: TypeAlias = 'dict[Method | Literal["websocket", "asgi"], RouteHandlerType]' +RouteHandlerMapItem: TypeAlias = 'dict[Method | Literal["websocket", "asgi"], BaseRouteHandler]' TemplateConfigType: TypeAlias = "TemplateConfig[EngineType]" # deprecated diff --git a/litestar/typing.py b/litestar/typing.py index 37dec75825..8a77b869c8 100644 --- a/litestar/typing.py +++ b/litestar/typing.py @@ -3,7 +3,6 @@ import dataclasses import warnings from collections import abc -from copy import deepcopy from dataclasses import dataclass, is_dataclass, replace from enum import Enum from inspect import Parameter, Signature @@ -21,7 +20,6 @@ NewType, NotRequired, Required, - Self, TypeAliasType, get_args, get_origin, @@ -144,9 +142,6 @@ class FieldDefinition: name: str """Field name.""" - def __deepcopy__(self, memo: dict[str, Any]) -> Self: - return type(self)(**{attr: deepcopy(getattr(self, attr)) for attr in self.__slots__}) - def __eq__(self, other: Any) -> bool: if not isinstance(other, FieldDefinition): return False diff --git a/litestar/utils/predicates.py b/litestar/utils/predicates.py index 4a110fcba9..3467c97f3d 100644 --- a/litestar/utils/predicates.py +++ b/litestar/utils/predicates.py @@ -73,6 +73,10 @@ class instances with ``async def __call__()`` defined. Returns: Bool determining if type of ``value`` is an awaitable. """ + from litestar.utils.sync import AsyncCallable + + if isinstance(value, AsyncCallable): + return True value = unwrap_partial(value) return iscoroutinefunction(value) or ( diff --git a/litestar/utils/scope/__init__.py b/litestar/utils/scope/__init__.py index 76af9ecad0..625e6dc79a 100644 --- a/litestar/utils/scope/__init__.py +++ b/litestar/utils/scope/__init__.py @@ -22,14 +22,11 @@ def get_serializer_from_scope(scope: Scope) -> Serializer: route_handler = scope["route_handler"] app = scope["litestar_app"] - if hasattr(route_handler, "resolve_type_encoders"): - type_encoders = route_handler.resolve_type_encoders() - else: - type_encoders = app.type_encoders or {} + type_encoders = route_handler.type_encoders if hasattr(route_handler, "type_encoders") else app.type_encoders or {} if response_class := ( - route_handler.resolve_response_class() # pyright: ignore - if hasattr(route_handler, "resolve_response_class") + route_handler.response_class # pyright: ignore + if hasattr(route_handler, "response_class") else app.response_class ): type_encoders = {**type_encoders, **(response_class.type_encoders or {})} diff --git a/litestar/utils/signature.py b/litestar/utils/signature.py index becd1bfea7..2f2c5f55e0 100644 --- a/litestar/utils/signature.py +++ b/litestar/utils/signature.py @@ -2,7 +2,6 @@ import sys import typing -from copy import deepcopy from dataclasses import dataclass, replace from inspect import Signature, getmembers, isclass, ismethod from itertools import chain @@ -193,13 +192,6 @@ class ParsedSignature: original_signature: Signature """The raw signature as returned by :func:`inspect.signature`""" - def __deepcopy__(self, memo: dict[str, Any]) -> Self: - return type(self)( - parameters={k: deepcopy(v) for k, v in self.parameters.items()}, - return_type=deepcopy(self.return_type), - original_signature=deepcopy(self.original_signature), - ) - @classmethod def from_fn(cls, fn: AnyCallable, signature_namespace: dict[str, Any]) -> Self: """Parse a function signature. diff --git a/litestar/utils/sync.py b/litestar/utils/sync.py index 7c91845a07..0f1cf88b56 100644 --- a/litestar/utils/sync.py +++ b/litestar/utils/sync.py @@ -8,6 +8,7 @@ Iterable, Iterator, TypeVar, + overload, ) from typing_extensions import ParamSpec @@ -22,7 +23,15 @@ T = TypeVar("T") -def ensure_async_callable(fn: Callable[P, T]) -> Callable[P, Awaitable[T]]: +@overload +def ensure_async_callable(fn: Callable[P, Awaitable[T]]) -> Callable[P, Awaitable[T]]: ... + + +@overload +def ensure_async_callable(fn: Callable[P, T]) -> Callable[P, Awaitable[T]]: ... + + +def ensure_async_callable(fn: Callable[P, T]) -> Callable[P, Awaitable[T]]: # pyright: ignore """Ensure that ``fn`` is an asynchronous callable. If it is an asynchronous, return the original object, else wrap it in an ``AsyncCallable`` diff --git a/pyproject.toml b/pyproject.toml index 83ffcc7082..69fdf823ab 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -237,6 +237,7 @@ filterwarnings = [ "ignore: Dropping max_length:litestar.exceptions.LitestarWarning:litestar.contrib.piccolo", "ignore: Python Debugger on exception enabled:litestar.exceptions.LitestarWarning:", "ignore: datetime.datetime.utcnow:DeprecationWarning:time_machine", + "ignore:Registering routes.*:litestar.exceptions.base_exceptions.LitestarWarning" ] markers = [ "sqlalchemy_integration: SQLAlchemy integration tests", @@ -410,7 +411,7 @@ lint.select = [ "SIM", # flake8-simplify "T10", # flake8-debugger "T20", # flake8-print - "TC", # flake8-type-checking + "TCH", # flake8-type-checking "TID", # flake8-tidy-imports "UP", # pyupgrade "W", # pycodestyle - warning @@ -461,7 +462,7 @@ classmethod-decorators = [ known-first-party = ["litestar", "tests", "examples"] [tool.ruff.lint.per-file-ignores] -"docs/**/*.*" = ["S", "B", "DTZ", "A", "TC", "ERA", "D", "RET"] +"docs/**/*.*" = ["S", "B", "DTZ", "A", "TCH", "ERA", "D", "RET"] "docs/examples/**" = ["T201"] "docs/examples/application_hooks/before_send_hook.py" = ["UP006"] "docs/examples/contrib/sqlalchemy/plugins/**/*.*" = ["UP006"] @@ -494,7 +495,7 @@ known-first-party = ["litestar", "tests", "examples"] "S", "S101", "SIM", - "TC", + "TCH", "TRY", "E721", ] diff --git a/tests/e2e/test_router_registration.py b/tests/e2e/test_router_registration.py index 357411108f..cdb33de8df 100644 --- a/tests/e2e/test_router_registration.py +++ b/tests/e2e/test_router_registration.py @@ -11,7 +11,6 @@ get, patch, post, - put, websocket, ) from litestar import ( @@ -46,7 +45,7 @@ async def ws(self, socket: WebSocket) -> None: def test_register_with_controller_class(controller: Type[Controller]) -> None: - router = Router(path="/base", route_handlers=[controller]) + router = Litestar(path="/base", route_handlers=[controller], openapi_config=None) assert len(router.routes) == 3 for route in router.routes: if isinstance(route, HTTPRoute): @@ -58,31 +57,9 @@ def test_register_with_controller_class(controller: Type[Controller]) -> None: assert route.path == "/base/test" -def test_register_controller_on_different_routers(controller: Type[Controller]) -> None: - first_router = Router(path="/first", route_handlers=[controller]) - second_router = Router(path="/second", route_handlers=[controller]) - third_router = Router(path="/third", route_handlers=[controller]) - - for router in (first_router, second_router, third_router): - for route in router.routes: - if hasattr(route, "route_handlers"): - for route_handler in [ - handler - for handler in route.route_handlers # pyright: ignore - if handler.handler_name != "options_handler" - ]: - assert route_handler.owner is not None - assert route_handler.owner.owner is not None - assert route_handler.owner.owner is router - else: - assert route.route_handler.owner is not None # pyright: ignore - assert route.route_handler.owner.owner is not None # pyright: ignore - assert route.route_handler.owner.owner is router # pyright: ignore - - def test_register_with_router_instance(controller: Type[Controller]) -> None: top_level_router = Router(path="/top-level", route_handlers=[controller]) - base_router = Router(path="/base", route_handlers=[top_level_router]) + base_router = Litestar(path="/base", route_handlers=[top_level_router], openapi_config=None) assert len(base_router.routes) == 3 for route in base_router.routes: @@ -108,7 +85,11 @@ def second_route_handler() -> None: def third_route_handler() -> None: pass - router = Router(path="/base", route_handlers=[first_route_handler, second_route_handler, third_route_handler]) + router = Litestar( + path="/base", + route_handlers=[first_route_handler, second_route_handler, third_route_handler], + openapi_config=None, + ) assert len(router.routes) == 2 for route in router.routes: if isinstance(route, HTTPRoute): @@ -132,7 +113,7 @@ def second_route_handler(self) -> None: pass with pytest.raises(ImproperlyConfiguredException): - Router(path="/base", route_handlers=[MyCustomClass]) + Litestar(path="/base", route_handlers=[MyCustomClass]) def test_register_already_registered_router() -> None: @@ -148,42 +129,11 @@ def test_register_router_on_itself() -> None: router.register(router) -def test_route_handler_method_view(controller: Type[Controller]) -> None: - @get(path="/root") - def handler() -> None: ... - - def _handler() -> None: ... - - put_handler = put("/modify")(_handler) - post_handler = post("/send")(_handler) - - first_router = Router(path="/first", route_handlers=[controller, post_handler, put_handler]) - second_router = Router(path="/second", route_handlers=[controller, post_handler, put_handler]) - - app = Litestar(route_handlers=[first_router, second_router, handler]) +def test_register_app_on_itself() -> None: + app = Litestar(path="/first", route_handlers=[]) - assert app.route_handler_method_view[str(handler)] == ["/root"] - assert app.route_handler_method_view[str(controller.get_method)] == [ # type: ignore[attr-defined] - "/first/test", - "/second/test", - ] - - assert app.route_handler_method_view[str(controller.ws)] == [ # type: ignore[attr-defined] - "/first/test/socket", - "/second/test/socket", - ] - assert app.route_handler_method_view[str(put_handler)] == [ - "/first/send", - "/first/modify", - "/second/send", - "/second/modify", - ] - assert app.route_handler_method_view[str(post_handler)] == [ - "/first/send", - "/first/modify", - "/second/send", - "/second/modify", - ] + with pytest.raises(ImproperlyConfiguredException): + app.register(app) def test_missing_path_param_type(controller: Type[Controller]) -> None: @@ -193,5 +143,5 @@ def test_missing_path_param_type(controller: Type[Controller]) -> None: def handler() -> None: ... with pytest.raises(ImproperlyConfiguredException) as exc: - Router(path="/", route_handlers=[handler]) + Litestar(route_handlers=[handler]) assert missing_path_type in exc.value.args[0] diff --git a/tests/unit/test_app.py b/tests/unit/test_app.py index 0a6e139094..d0f93ee517 100644 --- a/tests/unit/test_app.py +++ b/tests/unit/test_app.py @@ -14,7 +14,6 @@ from pytest import MonkeyPatch from litestar import Litestar, MediaType, Request, Response, get -from litestar._asgi.asgi_router import ASGIRouter from litestar.config.app import AppConfig, ExperimentalFeatures from litestar.config.response_cache import ResponseCacheConfig from litestar.contrib.sqlalchemy.plugins import SQLAlchemySerializationPlugin @@ -28,7 +27,6 @@ ) from litestar.logging.config import LoggingConfig from litestar.plugins import CLIPluginProtocol -from litestar.router import Router from litestar.status_codes import HTTP_200_OK, HTTP_400_BAD_REQUEST, HTTP_500_INTERNAL_SERVER_ERROR from litestar.testing import TestClient, create_test_client @@ -171,8 +169,6 @@ def test_app_config_object_used(app_config_object: AppConfig, monkeypatch: pytes # Things that we don't actually need to call for this test monkeypatch.setattr(Litestar, "register", MagicMock()) monkeypatch.setattr(Litestar, "_create_asgi_handler", MagicMock()) - monkeypatch.setattr(Router, "__init__", MagicMock()) - monkeypatch.setattr(ASGIRouter, "__init__", MagicMock(return_value=None)) # instantiates the app with an `on_app_config` that returns our patched `AppConfig` object. Litestar(on_app_init=[MagicMock(return_value=app_config_object)]) diff --git a/tests/unit/test_connection/test_websocket.py b/tests/unit/test_connection/test_websocket.py index 2fe96f9eb9..e41c356324 100644 --- a/tests/unit/test_connection/test_websocket.py +++ b/tests/unit/test_connection/test_websocket.py @@ -67,7 +67,7 @@ async def handler(socket: WebSocket[Any, Any, State]) -> None: assert dict(ws.scope["headers"])[b"test"] == b"hello-world" -async def test_custom_request_class() -> None: +async def test_custom_websocket_class() -> None: value: Any = {} class MyWebSocket(WebSocket[Any, Any, State]): diff --git a/tests/unit/test_di.py b/tests/unit/test_di.py index 781e805cb1..cdd286ada7 100644 --- a/tests/unit/test_di.py +++ b/tests/unit/test_di.py @@ -165,3 +165,16 @@ def test_raises_when_dependency_is_not_callable() -> None: def test_raises_when_generator_dependency_is_cached(dep: Any) -> None: with pytest.raises(ImproperlyConfiguredException): Provide(dep, use_cache=True) + + +def test_provide_raises_on_unsafe_signature_access() -> None: + async def foo() -> None: + pass + + provide = Provide(foo) + + with pytest.raises(ValueError): + provide.signature_model + + with pytest.raises(ValueError): + provide.parsed_fn_signature diff --git a/tests/unit/test_dto/test_integration.py b/tests/unit/test_dto/test_integration.py index 99f2cb1f6e..3b12a5cf8c 100644 --- a/tests/unit/test_dto/test_integration.py +++ b/tests/unit/test_dto/test_integration.py @@ -98,9 +98,10 @@ def test_enable_experimental_backend_override_in_dto_config(ModelDataDTO: type[A def handler(data: Model) -> Model: return data - Litestar(route_handlers=[handler]) + app = Litestar(route_handlers=[handler]) + resolved_handler = app.route_handler_method_map["/"]["POST"] - backend = handler.resolve_data_dto()._dto_backends[handler.handler_id]["data_backend"] # type: ignore[union-attr] + backend = resolved_handler.data_dto._dto_backends[resolved_handler.handler_id]["data_backend"] # type: ignore[union-attr] assert isinstance(backend, DTOBackend) @@ -111,7 +112,8 @@ def test_use_codegen_backend_by_default(ModelDataDTO: type[AbstractDTO]) -> None def handler(data: Model) -> Model: return data - Litestar(route_handlers=[handler]) + app = Litestar(route_handlers=[handler]) + resolved_handler = app.route_handler_method_map["/"]["POST"] - backend = handler.resolve_data_dto()._dto_backends[handler.handler_id]["data_backend"] # type: ignore[union-attr] + backend = resolved_handler.data_dto._dto_backends[resolved_handler.handler_id]["data_backend"] # type: ignore[union-attr] assert isinstance(backend, DTOBackend) diff --git a/tests/unit/test_guards.py b/tests/unit/test_guards.py index 9224a8f75c..9cbaf6a51e 100644 --- a/tests/unit/test_guards.py +++ b/tests/unit/test_guards.py @@ -2,35 +2,47 @@ import pytest -from litestar import Litestar, Router, asgi, get, websocket +from litestar import Controller, Litestar, Router, asgi, get, websocket from litestar.connection import WebSocket from litestar.exceptions import PermissionDeniedException, WebSocketDisconnect from litestar.response.base import ASGIResponse from litestar.status_codes import HTTP_200_OK, HTTP_403_FORBIDDEN from litestar.testing import create_test_client -from litestar.types import Receive, Scope, Send +from litestar.types import Guard, Receive, Scope, Send if TYPE_CHECKING: from litestar.connection import ASGIConnection from litestar.handlers.base import BaseRouteHandler -async def local_guard(_: "ASGIConnection", route_handler: "BaseRouteHandler") -> None: - if not route_handler.opt or not route_handler.opt.get("allow_all"): - raise PermissionDeniedException("local") +@pytest.fixture() +def local_guard() -> Guard: + async def local_guard_fn(_: "ASGIConnection", route_handler: "BaseRouteHandler") -> None: + if not route_handler.opt or not route_handler.opt.get("allow_all"): + raise PermissionDeniedException("local") + return local_guard_fn -def router_guard(connection: "ASGIConnection", _: "BaseRouteHandler") -> None: - if not connection.headers.get("Authorization-Router"): - raise PermissionDeniedException("router") +@pytest.fixture() +def router_guard() -> Guard: + async def router_guard_fn(connection: "ASGIConnection", _: "BaseRouteHandler") -> None: + if not connection.headers.get("Authorization-Router"): + raise PermissionDeniedException("router") -def app_guard(connection: "ASGIConnection", _: "BaseRouteHandler") -> None: - if not connection.headers.get("Authorization"): - raise PermissionDeniedException("app") + return router_guard_fn -def test_guards_with_http_handler() -> None: +@pytest.fixture() +def app_guard() -> Guard: + async def app_guard_fn(connection: "ASGIConnection", _: "BaseRouteHandler") -> None: + if not connection.headers.get("Authorization"): + raise PermissionDeniedException("app") + + return app_guard_fn + + +def test_guards_with_http_handler(app_guard: Guard, local_guard: Guard) -> None: @get(path="/secret", guards=[local_guard]) def my_http_route_handler() -> None: ... @@ -46,7 +58,7 @@ def my_http_route_handler() -> None: ... assert response.status_code == HTTP_200_OK -def test_guards_with_asgi_handler() -> None: +def test_guards_with_asgi_handler(app_guard: Guard, local_guard: Guard) -> None: @asgi(path="/secret", guards=[local_guard]) async def my_asgi_handler(scope: Scope, receive: Receive, send: Send) -> None: response = ASGIResponse(body=b'{"hello": "world"}') @@ -64,7 +76,7 @@ async def my_asgi_handler(scope: Scope, receive: Receive, send: Send) -> None: assert response.status_code == HTTP_200_OK -def test_guards_with_websocket_handler() -> None: +def test_guards_with_websocket_handler(local_guard: Guard) -> None: @websocket(path="/", guards=[local_guard]) async def my_websocket_route_handler(socket: WebSocket) -> None: await socket.accept() @@ -84,26 +96,40 @@ async def my_websocket_route_handler(socket: WebSocket) -> None: ws.send_json({"data": "123"}) -def test_guards_layering_for_same_route_handler() -> None: +def test_guard_ordering(local_guard: Guard, router_guard: Guard, app_guard: Guard) -> None: + async def controller_guard(_: "ASGIConnection", route_handler: "BaseRouteHandler") -> None: + pass + + class MyController(Controller): + guards = [controller_guard] + + @get(path="/http", guards=[local_guard]) + def http_route_handler(self) -> None: ... + + router = Router(path="/router", route_handlers=[MyController], guards=[router_guard]) + app = Litestar(route_handlers=[router], guards=[app_guard]) + + assert app.asgi_router.root_route_map_node.children["/router/http"].asgi_handlers["GET"][1].guards == ( + app_guard, + router_guard, + controller_guard, + local_guard, + ) + + +def test_guards_layering_for_same_route_handler(local_guard: Guard, router_guard: Guard, app_guard: Guard) -> None: @get(path="/http", guards=[local_guard]) def http_route_handler() -> None: ... router = Router(path="/router", route_handlers=[http_route_handler], guards=[router_guard]) app = Litestar(route_handlers=[http_route_handler, router], guards=[app_guard]) - assert ( - len( - app.asgi_router.root_route_map_node.children["/http"] - .asgi_handlers["GET"][1] # type: ignore[arg-type] - ._resolved_guards - ) - == 2 + assert app.asgi_router.root_route_map_node.children["/http"].asgi_handlers["GET"][1].guards == ( + app_guard, + local_guard, ) - assert ( - len( - app.asgi_router.root_route_map_node.children["/router/http"] - .asgi_handlers["GET"][1] # type: ignore[arg-type] - ._resolved_guards - ) - == 3 + assert app.asgi_router.root_route_map_node.children["/router/http"].asgi_handlers["GET"][1].guards == ( + app_guard, + router_guard, + local_guard, ) diff --git a/tests/unit/test_handlers/test_asgi_handlers/test_validations.py b/tests/unit/test_handlers/test_asgi_handlers/test_validations.py index 3e1e245fa8..ee1ab736a8 100644 --- a/tests/unit/test_handlers/test_asgi_handlers/test_validations.py +++ b/tests/unit/test_handlers/test_asgi_handlers/test_validations.py @@ -17,28 +17,28 @@ async def fn_without_scope_arg(receive: "Receive", send: "Send") -> None: with pytest.raises(ImproperlyConfiguredException): handler = asgi(path="/")(fn_without_scope_arg) - handler.on_registration(Litestar(), ASGIRoute(path="/", route_handler=handler)) + handler.on_registration(ASGIRoute(path="/", route_handler=handler), app=Litestar()) async def fn_without_receive_arg(scope: "Scope", send: "Send") -> None: pass with pytest.raises(ImproperlyConfiguredException): handler = asgi(path="/")(fn_without_receive_arg) - handler.on_registration(Litestar(), ASGIRoute(path="/", route_handler=handler)) + handler.on_registration(ASGIRoute(path="/", route_handler=handler), app=Litestar()) async def fn_without_send_arg(scope: "Scope", receive: "Receive") -> None: pass with pytest.raises(ImproperlyConfiguredException): handler = asgi(path="/")(fn_without_send_arg) - handler.on_registration(Litestar(), ASGIRoute(path="/", route_handler=handler)) + handler.on_registration(ASGIRoute(path="/", route_handler=handler), app=Litestar()) async def fn_with_return_annotation(scope: "Scope", receive: "Receive", send: "Send") -> dict: return {} with pytest.raises(ImproperlyConfiguredException): handler = asgi(path="/")(fn_with_return_annotation) - handler.on_registration(Litestar(), ASGIRoute(path="/", route_handler=handler)) + handler.on_registration(ASGIRoute(path="/", route_handler=handler), app=Litestar()) asgi_handler_with_no_fn = asgi(path="/") @@ -50,4 +50,4 @@ def sync_fn(scope: "Scope", receive: "Receive", send: "Send") -> None: with pytest.raises(ImproperlyConfiguredException): handler = asgi(path="/")(sync_fn) # type: ignore[arg-type] - handler.on_registration(Litestar(), ASGIRoute(path="/", route_handler=handler)) + handler.on_registration(ASGIRoute(path="/", route_handler=handler), app=Litestar()) diff --git a/tests/unit/test_handlers/test_base_handlers/test_resolution.py b/tests/unit/test_handlers/test_base_handlers/test_resolution.py index 12809610a4..37626bdc92 100644 --- a/tests/unit/test_handlers/test_base_handlers/test_resolution.py +++ b/tests/unit/test_handlers/test_base_handlers/test_resolution.py @@ -1,7 +1,9 @@ from typing import Awaitable, Callable +from unittest.mock import AsyncMock from litestar import Controller, Litestar, Router, get from litestar.di import Provide +from litestar.params import Parameter def test_resolve_dependencies_without_provide() -> None: @@ -15,7 +17,7 @@ async def bar() -> None: async def handler() -> None: pass - assert handler.resolve_dependencies() == {"foo": Provide(foo), "bar": Provide(bar)} + assert handler.dependencies == {"foo": Provide(foo), "bar": Provide(bar)} def function_factory() -> Callable[[], Awaitable[None]]: @@ -46,7 +48,7 @@ async def handler(self) -> None: assert handler_map handler = handler_map["handler"] - assert handler.resolve_dependencies() == { + assert handler.dependencies == { "app": Provide(app_dependency), "router": Provide(router_dependency), "controller": Provide(controller_dependency), @@ -54,16 +56,80 @@ async def handler(self) -> None: } -def test_resolve_dependencies_cached() -> None: - dependency = Provide(function_factory()) +def test_resolve_type_encoders() -> None: + @get("/", type_encoders={int: str}) + def handler() -> None: + pass - @get(dependencies={"foo": dependency}) - async def handler() -> None: + assert handler.resolve_type_encoders() == {int: str} + + +def test_resolve_type_decoders() -> None: + type_decoders = [(lambda t: True, lambda v, t: t)] + + @get("/", type_decoders=type_decoders) + def handler() -> None: pass - @get(dependencies={"foo": dependency}) - async def handler_2() -> None: + assert handler.resolve_type_decoders() == type_decoders + + +def test_resolve_parameters() -> None: + parameters = {"foo": Parameter()} + + @get("/") + def handler() -> None: + pass + + handler = handler.merge(Router("/", parameters=parameters, route_handlers=[])) + assert handler.resolve_layered_parameters() == handler.parameter_field_definitions + + +def test_resolve_guards() -> None: + guard = AsyncMock() + + @get("/", guards=[guard]) + def handler() -> None: + pass + + assert handler.resolve_guards() == (guard,) + + +def test_resolve_dependencies() -> None: + dependency = AsyncMock() + + @get("/", dependencies={"foo": dependency}) + def handler() -> None: + pass + + assert handler.resolve_dependencies() == handler.dependencies + + +def test_resolve_middleware() -> None: + middleware = AsyncMock() + + @get("/", middleware=[middleware]) + def handler() -> None: + pass + + assert handler.resolve_middleware() == handler.middleware + + +def test_exception_handlers() -> None: + exception_handler = AsyncMock() + + @get("/", exception_handlers={ValueError: exception_handler}) + def handler() -> None: + pass + + assert handler.resolve_exception_handlers() == {ValueError: exception_handler} + + +def test_resolve_signature_namespace() -> None: + namespace = {"foo": object()} + + @get("/", signature_namespace=namespace) + def handler() -> None: pass - assert handler.resolve_dependencies() is handler.resolve_dependencies() - assert handler_2.resolve_dependencies() is handler_2.resolve_dependencies() + assert handler.resolve_signature_namespace() == namespace diff --git a/tests/unit/test_handlers/test_http_handlers/test_head.py b/tests/unit/test_handlers/test_http_handlers/test_head.py index af17b045c9..adf2b81ab4 100644 --- a/tests/unit/test_handlers/test_http_handlers/test_head.py +++ b/tests/unit/test_handlers/test_http_handlers/test_head.py @@ -28,8 +28,7 @@ def test_head_decorator_raises_validation_error_if_body_is_declared() -> None: def handler() -> dict: return {} - handler.on_registration(Litestar(), HTTPRoute(path="/", route_handlers=[handler])) - Litestar(route_handlers=[handler]) + handler.on_registration(HTTPRoute(path="/", route_handlers=[handler]), app=Litestar()) def test_head_decorator_none_response_return_value_allowed() -> None: @@ -39,15 +38,17 @@ def test_head_decorator_none_response_return_value_allowed() -> None: class MyResponse(Generic[T], Response[T]): pass - @head("/1") + @head("/") def handler() -> Response[None]: return Response(None) - @head("/2") + Litestar([handler]) + + @head("/") def handler_subclass() -> MyResponse[None]: return MyResponse(None) - Litestar(route_handlers=[handler, handler_subclass]) + Litestar([handler_subclass]) def test_head_decorator_does_not_raise_for_file_response() -> None: @@ -57,8 +58,6 @@ def handler() -> "File": Litestar(route_handlers=[handler]) - handler.on_registration(Litestar(), HTTPRoute(path="/", route_handlers=[handler])) - def test_head_decorator_does_not_raise_for_asgi_file_response() -> None: @head("/") @@ -66,5 +65,3 @@ def handler() -> ASGIFileResponse: return ASGIFileResponse(file_path=Path("test_head.py")) Litestar(route_handlers=[handler]) - - handler.on_registration(Litestar(), HTTPRoute(path="/", route_handlers=[handler])) diff --git a/tests/unit/test_handlers/test_http_handlers/test_kwarg_handling.py b/tests/unit/test_handlers/test_http_handlers/test_kwarg_handling.py index 326972ecac..e0be8744a8 100644 --- a/tests/unit/test_handlers/test_http_handlers/test_kwarg_handling.py +++ b/tests/unit/test_handlers/test_http_handlers/test_kwarg_handling.py @@ -19,8 +19,8 @@ def dummy_method() -> None: http_method=st.one_of(st.sampled_from(HttpMethod), st.lists(st.sampled_from(HttpMethod))), media_type=st.sampled_from(MediaType), include_in_schema=st.booleans(), - response_class=st.one_of(st.none(), st.just(Response)), - response_headers=st.one_of(st.none(), st.builds(list)), + response_class=st.one_of(st.just(Response)), + response_headers=st.builds(frozenset), status_code=st.one_of(st.none(), st.integers(min_value=200, max_value=204)), path=st.one_of(st.none(), st.text()), ) diff --git a/tests/unit/test_handlers/test_http_handlers/test_media_type.py b/tests/unit/test_handlers/test_http_handlers/test_media_type.py index a2ff10b60e..b249d06aef 100644 --- a/tests/unit/test_handlers/test_http_handlers/test_media_type.py +++ b/tests/unit/test_handlers/test_http_handlers/test_media_type.py @@ -4,7 +4,6 @@ import pytest from litestar import Litestar, MediaType, get -from litestar.routes import HTTPRoute from tests.models import DataclassPerson @@ -37,7 +36,6 @@ def test_media_type_inference(annotation: Any, expected_media_type: MediaType) - def handler() -> annotation: return None - Litestar(route_handlers=[handler]) - - handler.on_registration(Litestar(), HTTPRoute(path="/", route_handlers=[handler])) - assert handler.media_type == expected_media_type + app = Litestar(route_handlers=[handler]) + resolved_handler = app.route_handler_method_map["/"]["GET"] + assert resolved_handler.media_type == expected_media_type # type: ignore[attr-defined] diff --git a/tests/unit/test_handlers/test_http_handlers/test_options.py b/tests/unit/test_handlers/test_http_handlers/test_options.py new file mode 100644 index 0000000000..60ed61048c --- /dev/null +++ b/tests/unit/test_handlers/test_http_handlers/test_options.py @@ -0,0 +1,27 @@ +import pytest + +from litestar import Router, get +from litestar.testing import create_test_client + + +@pytest.mark.xfail(reason="broken behaviour that never really worked") +def test_option_handler_inherits_layer_config() -> None: + # Currently broken. The reason this cannot work reliably is that when auto-creating + # OPTIONS handlers, we *assume* that it should inherit the configuration of + # path-adjacent route handlers, e.g. if a route handler is defined on '/', and + # registered on a router with the path '/one', it would receive the configuration + # from that router. However, since it is possible to have multiple handlers per + # path, which can also be registered with different routers, we would still only + # create one 'OPTIONS' handler for each path, even if multiple routers exist with + # handlers for that path. In this case, it is unclear which configuration the + # 'OPTIONS' handler should receive + @get("/") + def handler() -> None: + return None + + router = Router(path="/router", route_handlers=[handler], response_headers={"router": "router"}) + + with create_test_client(route_handlers=[handler, router], response_headers={"app": "app"}) as client: + res = client.options("/") + assert res.headers.get("app") == "app" + assert res.headers.get("router") == "router" diff --git a/tests/unit/test_handlers/test_http_handlers/test_resolution.py b/tests/unit/test_handlers/test_http_handlers/test_resolution.py index f294599b39..634243e2a7 100644 --- a/tests/unit/test_handlers/test_http_handlers/test_resolution.py +++ b/tests/unit/test_handlers/test_http_handlers/test_resolution.py @@ -1,6 +1,9 @@ +from unittest.mock import Mock + import pytest -from litestar import Controller, Litestar, Router, post +from litestar import Controller, Litestar, Request, Response, Router, get, post +from litestar.datastructures import ResponseHeader from litestar.exceptions import ImproperlyConfiguredException from litestar.types import Empty @@ -23,11 +26,12 @@ def controller_handler(self) -> None: router = Router("/", route_handlers=[router_handler], request_max_body_size=1) app = Litestar(route_handlers=[app_handler, router, MyController], request_max_body_size=3) - assert router_handler.resolve_request_max_body_size() == 1 - assert app_handler.resolve_request_max_body_size() == 3 - assert ( - next(r for r in app.routes if r.path == "/3").route_handler_map["POST"].resolve_request_max_body_size() == 2 # type: ignore[union-attr] - ) + handler_1 = next(r for r in app.routes if r.path == "/1").route_handler_map["POST"] # type: ignore[union-attr] + handler_2 = next(r for r in app.routes if r.path == "/2").route_handler_map["POST"] # type: ignore[union-attr] + handler_3 = next(r for r in app.routes if r.path == "/3").route_handler_map["POST"] # type: ignore[union-attr] + assert handler_1.request_max_body_size == handler_1.resolve_request_max_body_size() == 1 + assert handler_2.request_max_body_size == handler_2.resolve_request_max_body_size() == 3 + assert handler_3.request_max_body_size == handler_3.resolve_request_max_body_size() == 2 def test_resolve_request_max_body_size_none() -> None: @@ -36,7 +40,7 @@ def router_handler() -> None: pass Litestar([router_handler]) - assert router_handler.resolve_request_max_body_size() is None + assert router_handler.request_max_body_size is None def test_resolve_request_max_body_size_app_default() -> None: @@ -46,7 +50,11 @@ def router_handler() -> None: app = Litestar(route_handlers=[router_handler]) - assert router_handler.resolve_request_max_body_size() == app.request_max_body_size == 10_000_000 + assert ( + next(r for r in app.routes if r.path == "/").route_handler_map["POST"].request_max_body_size # type: ignore[union-attr] + == app.request_max_body_size + == 10_000_000 + ) def test_resolve_request_max_body_size_empty_on_all_layers_raises() -> None: @@ -54,13 +62,89 @@ def test_resolve_request_max_body_size_empty_on_all_layers_raises() -> None: def handler_one() -> None: pass - Litestar([handler_one], request_max_body_size=Empty) # type: ignore[arg-type] - with pytest.raises(ImproperlyConfiguredException): - handler_one.resolve_request_max_body_size() + with pytest.raises(ImproperlyConfiguredException, match="'request_max_body_size' set to 'Empty'"): + Litestar([handler_one], request_max_body_size=Empty) # type: ignore[arg-type] - @post("/") - def handler_two() -> None: + +def test_resolve_request_class() -> None: + @get() + async def handler() -> None: + pass + + app = Litestar(route_handlers=[handler]) + assert app.route_handler_method_map["/"]["GET"].resolve_request_class() is Request # type: ignore[attr-defined] + + +def test_resolve_response_class() -> None: + @get() + async def handler() -> None: + pass + + app = Litestar(route_handlers=[handler]) + assert app.route_handler_method_map["/"]["GET"].resolve_response_class() is Response # type: ignore[attr-defined] + + +def test_resolve_response_headers() -> None: + @get(response_headers={"foo": "bar"}) + async def handler() -> None: + pass + + app = Litestar(route_handlers=[handler]) + assert app.route_handler_method_map["/"]["GET"].resolve_response_headers() == frozenset( # type: ignore[attr-defined] + [ResponseHeader(name="foo", value="bar")] + ) + + +def test_resolve_before_request() -> None: + before_request = Mock() + + @get(before_request=before_request) + async def handler() -> None: + pass + + app = Litestar(route_handlers=[handler]) + resolved_handler = app.route_handler_method_map["/"]["GET"] + assert resolved_handler.resolve_before_request() is resolved_handler.before_request # type: ignore[attr-defined] + + +def test_resolve_after_response() -> None: + after_response = Mock() + + @get(after_response=after_response) + async def handler() -> None: + pass + + app = Litestar(route_handlers=[handler]) + resolved_handler = app.route_handler_method_map["/"]["GET"] + assert resolved_handler.resolve_after_response() is resolved_handler.after_response # type: ignore[attr-defined] + + +def test_resolve_include_in_schema() -> None: + @get(include_in_schema=False) + async def handler() -> None: + pass + + app = Litestar(route_handlers=[handler]) + assert app.route_handler_method_map["/"]["GET"].resolve_include_in_schema() is False # type: ignore[attr-defined] + + +def test_resolve_security() -> None: + security = [{"foo": ["bar"]}] + + @get(security=security) + async def handler() -> None: + pass + + app = Litestar(route_handlers=[handler]) + resolved_handler = app.route_handler_method_map["/"]["GET"] + assert resolved_handler.resolve_security() == resolved_handler.security # type: ignore[attr-defined] + + +def test_resolve_tags() -> None: + @get(tags=["foo"]) + async def handler() -> None: pass - with pytest.raises(ImproperlyConfiguredException): - handler_two.resolve_request_max_body_size() + app = Litestar(route_handlers=[handler]) + resolved_handler = app.route_handler_method_map["/"]["GET"] + assert resolved_handler.resolve_tags() == resolved_handler.tags # type: ignore[attr-defined] diff --git a/tests/unit/test_handlers/test_http_handlers/test_validations.py b/tests/unit/test_handlers/test_http_handlers/test_validations.py index 94c2efe7c5..0c80605717 100644 --- a/tests/unit/test_handlers/test_http_handlers/test_validations.py +++ b/tests/unit/test_handlers/test_http_handlers/test_validations.py @@ -10,7 +10,6 @@ from litestar.handlers.http_handlers import HTTPRouteHandler from litestar.params import Body from litestar.response import File, Redirect -from litestar.routes import HTTPRoute from litestar.status_codes import ( HTTP_100_CONTINUE, HTTP_200_OK, @@ -39,7 +38,7 @@ def test_route_handler_validation_http_method() -> None: async def test_function_validation() -> None: - with pytest.raises(ImproperlyConfiguredException): + with pytest.raises(ImproperlyConfiguredException, match="Missing return type annotation "): @get(path="/") def method_with_no_annotation(): # type: ignore[no-untyped-def] @@ -47,11 +46,10 @@ def method_with_no_annotation(): # type: ignore[no-untyped-def] Litestar(route_handlers=[method_with_no_annotation]) - method_with_no_annotation.on_registration( - Litestar(), HTTPRoute(path="/", route_handlers=[method_with_no_annotation]) - ) - - with pytest.raises(ImproperlyConfiguredException): + with pytest.raises( + ImproperlyConfiguredException, + match="A status code 204, 304 or in the range below 200 does not support a response body", + ): @delete(path="/") def method_with_no_content() -> Dict[str, str]: @@ -59,9 +57,10 @@ def method_with_no_content() -> Dict[str, str]: Litestar(route_handlers=[method_with_no_content]) - method_with_no_content.on_registration(Litestar(), HTTPRoute(path="/", route_handlers=[method_with_no_content])) - - with pytest.raises(ImproperlyConfiguredException): + with pytest.raises( + ImproperlyConfiguredException, + match="A status code 204, 304 or in the range below 200 does not support a response body", + ): @get(path="/", status_code=HTTP_304_NOT_MODIFIED) def method_with_not_modified() -> Dict[str, str]: @@ -69,11 +68,10 @@ def method_with_not_modified() -> Dict[str, str]: Litestar(route_handlers=[method_with_not_modified]) - method_with_not_modified.on_registration( - Litestar(), HTTPRoute(path="/", route_handlers=[method_with_not_modified]) - ) - - with pytest.raises(ImproperlyConfiguredException): + with pytest.raises( + ImproperlyConfiguredException, + match="A status code 204, 304 or in the range below 200 does not support a response body", + ): @get(path="/", status_code=HTTP_100_CONTINUE) def method_with_status_lower_than_200() -> Dict[str, str]: @@ -81,37 +79,29 @@ def method_with_status_lower_than_200() -> Dict[str, str]: Litestar(route_handlers=[method_with_status_lower_than_200]) - method_with_status_lower_than_200.on_registration( - Litestar(), HTTPRoute(path="/", route_handlers=[method_with_status_lower_than_200]) - ) - @get(path="/", status_code=HTTP_307_TEMPORARY_REDIRECT) def redirect_method() -> Redirect: return Redirect("/test") Litestar(route_handlers=[redirect_method]) - redirect_method.on_registration(Litestar(), HTTPRoute(path="/", route_handlers=[redirect_method])) - @get(path="/") def file_method() -> File: return File(path=Path("."), filename="test_validations.py") Litestar(route_handlers=[file_method]) - file_method.on_registration(Litestar(), HTTPRoute(path="/", route_handlers=[file_method])) - assert not file_method.media_type - with pytest.raises(ImproperlyConfiguredException): + with pytest.raises(ImproperlyConfiguredException, match="The 'socket' kwarg is not supported"): @get(path="/test") def test_function_1(socket: WebSocket) -> None: return None - test_function_1.on_registration(Litestar(), HTTPRoute(path="/", route_handlers=[test_function_1])) + Litestar([test_function_1]) - with pytest.raises(ImproperlyConfiguredException): + with pytest.raises(ImproperlyConfiguredException, match="'data' kwarg is unsupported"): @get("/person") def test_function_2(self, data: DataclassPerson) -> None: # type: ignore[no-untyped-def] @@ -119,8 +109,6 @@ def test_function_2(self, data: DataclassPerson) -> None: # type: ignore[no-unt Litestar(route_handlers=[test_function_2]) - test_function_2.on_registration(Litestar(), HTTPRoute(path="/", route_handlers=[test_function_2])) - @pytest.mark.parametrize( ("return_annotation", "should_raise"), diff --git a/tests/unit/test_handlers/test_websocket_handlers/test_handle_websocket.py b/tests/unit/test_handlers/test_websocket_handlers/test_handle_websocket.py index da4175158e..43fdb658ae 100644 --- a/tests/unit/test_handlers/test_websocket_handlers/test_handle_websocket.py +++ b/tests/unit/test_handlers/test_websocket_handlers/test_handle_websocket.py @@ -1,6 +1,10 @@ from typing import List +from unittest.mock import MagicMock + +import pytest from litestar import Controller, Router, WebSocket, websocket +from litestar.exceptions import ImproperlyConfiguredException from litestar.testing import create_test_client @@ -49,3 +53,12 @@ async def simple_websocket_handler( ws.send_json({"data": "123"}) data = ws.receive_json() assert data == {"a": 1, "b": "two", "c": 3.0, "d": ["d"]} + + +async def test_not_finalized_raises() -> None: + @websocket("/") + async def handler(socket: WebSocket) -> None: + pass + + with pytest.raises(ImproperlyConfiguredException, match="handler parameter model not defined"): + await handler.handle(MagicMock()) diff --git a/tests/unit/test_handlers/test_websocket_handlers/test_listeners.py b/tests/unit/test_handlers/test_websocket_handlers/test_listeners.py index a8e8c29934..7ef2ac1d88 100644 --- a/tests/unit/test_handlers/test_websocket_handlers/test_listeners.py +++ b/tests/unit/test_handlers/test_websocket_handlers/test_listeners.py @@ -239,7 +239,7 @@ def test_listener_callback_no_data_arg_raises() -> None: @websocket_listener("/") def handler() -> None: ... - handler.on_registration(Litestar(), WebSocketRoute(path="/", route_handler=handler)) + handler.on_registration(WebSocketRoute(path="/", route_handler=handler), app=Litestar()) def test_listener_callback_request_and_body_arg_raises() -> None: @@ -248,14 +248,14 @@ def test_listener_callback_request_and_body_arg_raises() -> None: @websocket_listener("/") def handler_request(data: str, request: Request) -> None: ... - handler_request.on_registration(Litestar(), WebSocketRoute(path="/", route_handler=handler_request)) + handler_request.on_registration(WebSocketRoute(path="/", route_handler=handler_request), app=Litestar()) with pytest.raises(ImproperlyConfiguredException): @websocket_listener("/") def handler_body(data: str, body: bytes) -> None: ... - handler_body.on_registration(Litestar(), WebSocketRoute(path="/", route_handler=handler_body)) + handler_body.on_registration(WebSocketRoute(path="/", route_handler=handler_body), app=Litestar()) def test_listener_accept_connection_callback() -> None: diff --git a/tests/unit/test_handlers/test_websocket_handlers/test_resolution.py b/tests/unit/test_handlers/test_websocket_handlers/test_resolution.py new file mode 100644 index 0000000000..5fa654aa3f --- /dev/null +++ b/tests/unit/test_handlers/test_websocket_handlers/test_resolution.py @@ -0,0 +1,9 @@ +from litestar import WebSocket, websocket + + +def test_resolve_websocket_class() -> None: + @websocket() + async def handler(socket: WebSocket) -> None: + pass + + assert handler.resolve_websocket_class() is handler.websocket_class diff --git a/tests/unit/test_handlers/test_websocket_handlers/test_validations.py b/tests/unit/test_handlers/test_websocket_handlers/test_validations.py index 57b8f19ce6..06c5a31f10 100644 --- a/tests/unit/test_handlers/test_websocket_handlers/test_validations.py +++ b/tests/unit/test_handlers/test_websocket_handlers/test_validations.py @@ -14,7 +14,7 @@ def fn_without_socket_arg(websocket: WebSocket) -> None: with pytest.raises(ImproperlyConfiguredException): handler = websocket(path="/")(fn_without_socket_arg) # type: ignore[arg-type] - handler.on_registration(Litestar(), WebSocketRoute(path="/", route_handler=handler)) + handler.on_registration(WebSocketRoute(path="/", route_handler=handler), app=Litestar()) def test_raises_for_return_annotation() -> None: @@ -23,7 +23,7 @@ async def fn_with_return_annotation(socket: WebSocket) -> dict: with pytest.raises(ImproperlyConfiguredException): handler = websocket(path="/")(fn_with_return_annotation) - handler.on_registration(Litestar(), WebSocketRoute(path="/", route_handler=handler)) + handler.on_registration(WebSocketRoute(path="/", route_handler=handler), app=Litestar()) def test_raises_when_no_function() -> None: @@ -40,7 +40,7 @@ def test_raises_when_sync_handler_user() -> None: def sync_websocket_handler(socket: WebSocket) -> None: ... sync_websocket_handler.on_registration( - Litestar(), WebSocketRoute(path="/", route_handler=sync_websocket_handler) + WebSocketRoute(path="/", route_handler=sync_websocket_handler), app=Litestar() ) @@ -51,7 +51,7 @@ def test_raises_when_data_kwarg_is_used() -> None: async def websocket_handler_with_data_kwarg(socket: WebSocket, data: Any) -> None: ... websocket_handler_with_data_kwarg.on_registration( - Litestar(), WebSocketRoute(path="/", route_handler=websocket_handler_with_data_kwarg) + WebSocketRoute(path="/", route_handler=websocket_handler_with_data_kwarg), app=Litestar() ) @@ -62,7 +62,7 @@ def test_raises_when_request_kwarg_is_used() -> None: async def websocket_handler_with_request_kwarg(socket: WebSocket, request: Any) -> None: ... websocket_handler_with_request_kwarg.on_registration( - Litestar(), WebSocketRoute(path="/", route_handler=websocket_handler_with_request_kwarg) + WebSocketRoute(path="/", route_handler=websocket_handler_with_request_kwarg), app=Litestar() ) @@ -73,5 +73,5 @@ def test_raises_when_body_kwarg_is_used() -> None: async def websocket_handler_with_request_kwarg(socket: WebSocket, body: bytes) -> None: ... websocket_handler_with_request_kwarg.on_registration( - Litestar(), WebSocketRoute(path="/", route_handler=websocket_handler_with_request_kwarg) + WebSocketRoute(path="/", route_handler=websocket_handler_with_request_kwarg), app=Litestar() ) diff --git a/tests/unit/test_openapi/conftest.py b/tests/unit/test_openapi/conftest.py index 332d2e3173..80d8c3ddbd 100644 --- a/tests/unit/test_openapi/conftest.py +++ b/tests/unit/test_openapi/conftest.py @@ -20,7 +20,7 @@ def create_person_controller() -> Type[Controller]: class PersonController(Controller): path = "/{service_id:int}/person" - @get(sync_to_thread=False) + @get("/", sync_to_thread=False) def get_persons( self, # expected to be ignored @@ -54,7 +54,7 @@ def get_persons( ) -> List[DataclassPerson]: return [] - @post(media_type=MediaType.TEXT, sync_to_thread=False) + @post("/", media_type=MediaType.TEXT, sync_to_thread=False) def create_person( self, data: DataclassPerson, secret_header: str = Parameter(header="secret") ) -> DataclassPerson: diff --git a/tests/unit/test_openapi/test_config.py b/tests/unit/test_openapi/test_config.py index a0832a2530..60f8bd5440 100644 --- a/tests/unit/test_openapi/test_config.py +++ b/tests/unit/test_openapi/test_config.py @@ -40,7 +40,7 @@ def test_merged_components_correct() -> None: def test_allows_customization_of_operation_id_creator() -> None: def operation_id_creator(handler: "HTTPRouteHandler", _: Any, __: Any) -> str: - return handler.name or "" + return f"id_{handler.name}" if handler.name else "" @get(path="/1", name="x") def handler_1() -> None: @@ -59,7 +59,7 @@ def handler_2() -> None: "/1": { "get": { "deprecated": False, - "operationId": "x", + "operationId": "id_x", "responses": {"200": {"description": "Request fulfilled, document follows", "headers": {}}}, "summary": "Handler1", } @@ -67,7 +67,7 @@ def handler_2() -> None: "/2": { "get": { "deprecated": False, - "operationId": "y", + "operationId": "id_y", "responses": {"200": {"description": "Request fulfilled, document follows", "headers": {}}}, "summary": "Handler2", } diff --git a/tests/unit/test_openapi/test_request_body.py b/tests/unit/test_openapi/test_request_body.py index 7a1452aa96..620d5c3744 100644 --- a/tests/unit/test_openapi/test_request_body.py +++ b/tests/unit/test_openapi/test_request_body.py @@ -41,7 +41,7 @@ def _factory(route_handler: BaseRouteHandler, data_field: FieldDefinition) -> Re return create_request_body( context=openapi_context, handler_id=route_handler.handler_id, - resolved_data_dto=route_handler.resolve_data_dto(), + resolved_data_dto=route_handler.data_dto, data_field=data_field, ) @@ -175,9 +175,10 @@ def test_request_body_generation_with_dto(create_request: RequestBodyFactory) -> async def handler(data: Dict[str, Any]) -> None: return None - Litestar(route_handlers=[handler]) + app = Litestar(route_handlers=[handler]) + resolved_handler = app.route_handler_method_map["/form-upload"]["POST"] field_definition = FieldDefinition.from_annotation(Dict[str, Any]) - create_request(handler, field_definition) + create_request(resolved_handler, field_definition) mock_dto.create_openapi_schema.assert_called_once_with( - field_definition=field_definition, handler_id=handler.handler_id, schema_creator=ANY + field_definition=field_definition, handler_id=resolved_handler.handler_id, schema_creator=ANY ) diff --git a/tests/unit/test_openapi/test_responses.py b/tests/unit/test_openapi/test_responses.py index ab626f0829..ba9cb4a674 100644 --- a/tests/unit/test_openapi/test_responses.py +++ b/tests/unit/test_openapi/test_responses.py @@ -74,7 +74,7 @@ def test_create_responses( for route in Litestar(route_handlers=[person_controller]).routes: assert isinstance(route, HTTPRoute) for route_handler in route.route_handler_map.values(): - if route_handler.resolve_include_in_schema(): + if route_handler.include_in_schema: responses = create_factory(route_handler).create_responses(True) assert responses assert str(route_handler.status_code) in responses @@ -520,13 +520,14 @@ def test_response_generation_with_dto(create_factory: CreateFactoryFixture) -> N async def handler(data: Dict[str, Any]) -> Dict[str, Any]: return data - Litestar(route_handlers=[handler]) + app = Litestar(route_handlers=[handler]) + resolved_handler = app.route_handler_method_map["/form-upload"]["POST"] - factory = create_factory(handler) + factory = create_factory(resolved_handler) field_definition = FieldDefinition.from_annotation(Dict[str, Any]) factory.create_success_response() mock_dto.create_openapi_schema.assert_called_once_with( - field_definition=field_definition, handler_id=handler.handler_id, schema_creator=factory.schema_creator + field_definition=field_definition, handler_id=resolved_handler.handler_id, schema_creator=factory.schema_creator ) diff --git a/tests/unit/test_request_class_resolution.py b/tests/unit/test_request_class_resolution.py index 081126c625..8b25d51555 100644 --- a/tests/unit/test_request_class_resolution.py +++ b/tests/unit/test_request_class_resolution.py @@ -12,14 +12,14 @@ @pytest.mark.parametrize( - "handler_request_class, controller_request_class, router_request_class, app_request_class, has_default_app_class, expected", + "handler_request_class, controller_request_class, router_request_class, app_request_class, expected", ( - (HandlerRequest, ControllerRequest, RouterRequest, AppRequest, True, HandlerRequest), - (None, ControllerRequest, RouterRequest, AppRequest, True, ControllerRequest), - (None, None, RouterRequest, AppRequest, True, RouterRequest), - (None, None, None, AppRequest, True, AppRequest), - (None, None, None, None, True, Request), - (None, None, None, None, False, Request), + (HandlerRequest, ControllerRequest, RouterRequest, AppRequest, HandlerRequest), + (None, ControllerRequest, RouterRequest, AppRequest, ControllerRequest), + (None, None, RouterRequest, AppRequest, RouterRequest), + (None, None, None, AppRequest, AppRequest), + (None, None, None, None, Request), + (None, None, None, None, Request), ), ids=( "Custom class for all layers", @@ -35,31 +35,20 @@ def test_request_class_resolution_of_layers( controller_request_class: Optional[Type[Request]], router_request_class: Optional[Type[Request]], app_request_class: Optional[Type[Request]], - has_default_app_class: bool, expected: Type[Request], ) -> None: class MyController(Controller): - @get() + request_class = controller_request_class + + @get(request_class=handler_request_class) def handler(self, request: Request) -> None: assert type(request) is expected - if controller_request_class: - MyController.request_class = ControllerRequest - - router = Router(path="/", route_handlers=[MyController]) - - if router_request_class: - router.request_class = router_request_class + router = Router(path="/", route_handlers=[MyController], request_class=router_request_class) - app = Litestar(route_handlers=[router]) - - if app_request_class or not has_default_app_class: - app.request_class = app_request_class # type: ignore[assignment] + app = Litestar(route_handlers=[router], request_class=app_request_class) route_handler: HTTPRouteHandler = app.route_handler_method_map["/"][HttpMethod.GET] # type: ignore[assignment] - if handler_request_class: - route_handler.request_class = handler_request_class - - request_class = route_handler.resolve_request_class() + request_class = route_handler.request_class assert request_class is expected diff --git a/tests/unit/test_response/test_response_cookies.py b/tests/unit/test_response/test_response_cookies.py index 3c5deaa2e6..5b5093b805 100644 --- a/tests/unit/test_response/test_response_cookies.py +++ b/tests/unit/test_response/test_response_cookies.py @@ -1,6 +1,6 @@ from uuid import uuid4 -from litestar import Controller, HttpMethod, Litestar, Response, Router, get +from litestar import Controller, Litestar, Response, Router, get from litestar.datastructures import Cookie from litestar.status_codes import HTTP_200_OK from litestar.testing import create_test_client @@ -37,8 +37,8 @@ def test_method(self) -> None: response_cookies=[app_first, app_second], route_handlers=[first_router, second_router], ) - route_handler = app.routes[0].route_handler_map[HttpMethod.GET] # type: ignore[union-attr] - response_cookies = {cookie.key: cookie.value for cookie in route_handler.resolve_response_cookies()} + route_handler = app.routes[0].route_handler_map["GET"] # type: ignore[union-attr] + response_cookies = {cookie.key: cookie.value for cookie in route_handler.response_cookies} assert response_cookies["first"] == local_first.value assert response_cookies["second"] == controller_second.value assert response_cookies["third"] == router_second.value @@ -55,20 +55,7 @@ def handler_one() -> None: def handler_two() -> None: pass - assert handler_one.resolve_response_cookies() == handler_two.resolve_response_cookies() - - -def test_response_cookies_mapping_unresolved() -> None: - # this should never happen, as there's no way to create this situation which type-checks. - # we test for it nevertheless - - @get() - def handler_one() -> None: - pass - - handler_one.response_cookies = {"foo": "bar"} # type: ignore[assignment] - - assert handler_one.resolve_response_cookies() == frozenset([Cookie(key="foo", value="bar")]) + assert handler_one.resolve_response_cookies() == handler_two.response_cookies def test_response_cookie_rendering() -> None: diff --git a/tests/unit/test_response/test_response_headers.py b/tests/unit/test_response/test_response_headers.py index d7f7393da0..5555869cba 100644 --- a/tests/unit/test_response/test_response_headers.py +++ b/tests/unit/test_response/test_response_headers.py @@ -2,7 +2,7 @@ import pytest -from litestar import Controller, HttpMethod, Litestar, Router, get, post +from litestar import Controller, Litestar, Router, get, post from litestar.datastructures import CacheControlHeader, ETag, ResponseHeader from litestar.datastructures.headers import Header from litestar.status_codes import HTTP_201_CREATED @@ -38,8 +38,8 @@ def test_method(self) -> None: route_handlers=[first_router, second_router], ) - route_handler = app.routes[0].route_handler_map[HttpMethod.GET] # type: ignore[union-attr] - resolved_headers = {header.name: header for header in route_handler.resolve_response_headers()} + route_handler = app.routes[0].route_handler_map["GET"] # type: ignore[union-attr] + resolved_headers = {header.name: header for header in route_handler.response_headers} assert resolved_headers["first"].value == local_first.value assert resolved_headers["second"].value == controller_second.value assert resolved_headers["third"].value == router_second.value @@ -56,20 +56,7 @@ def handler_one() -> None: def handler_two() -> None: pass - assert handler_one.resolve_response_headers() == handler_two.resolve_response_headers() - - -def test_response_headers_mapping_unresolved() -> None: - # this should never happen, as there's no way to create this situation which type-checks. - # we test for it nevertheless - - @get() - def handler_one() -> None: - pass - - handler_one.response_headers = {"foo": "bar"} # type: ignore[assignment] - - assert handler_one.resolve_response_headers() == frozenset([ResponseHeader(name="foo", value="bar")]) + assert handler_one.response_headers == handler_two.response_headers def test_response_headers_rendering() -> None: @@ -137,7 +124,7 @@ def app_handler() -> None: "app": app_header, }.items(): response = client.get(path) - assert response.headers[expected_value.HEADER_NAME] == expected_value.to_header() + assert response.headers[expected_value.HEADER_NAME] == expected_value.to_header(), path @pytest.mark.parametrize( @@ -184,6 +171,6 @@ def my_handler() -> None: app = Litestar(route_handlers=[my_handler]) - route_handler = app.routes[0].route_handler_map[HttpMethod.GET] # type: ignore[union-attr] - resolved_headers = {header.name: header for header in route_handler.resolve_response_headers()} + route_handler = app.routes[0].route_handler_map["GET"] # type: ignore[union-attr] + resolved_headers = {header.name: header for header in route_handler.response_headers} assert resolved_headers[header.HEADER_NAME].value == header.to_header() diff --git a/tests/unit/test_response/test_response_to_asgi_response.py b/tests/unit/test_response/test_response_to_asgi_response.py index 4a23f69986..a0de936b23 100644 --- a/tests/unit/test_response/test_response_to_asgi_response.py +++ b/tests/unit/test_response/test_response_to_asgi_response.py @@ -1,18 +1,15 @@ from __future__ import annotations from inspect import iscoroutine -from json import loads from pathlib import Path from time import sleep from typing import TYPE_CHECKING, Any, Generator, Iterator, cast -import msgspec import pytest from starlette.responses import HTMLResponse, JSONResponse, PlainTextResponse from starlette.responses import Response as StarletteResponse -from litestar import HttpMethod, Litestar, MediaType, Request, Response, get, route -from litestar._signature import SignatureModel +from litestar import Litestar, MediaType, Request, Response, get from litestar.background_tasks import BackgroundTask from litestar.contrib.jinja import JinjaTemplateEngine from litestar.datastructures import Cookie, ResponseHeader @@ -27,8 +24,6 @@ from litestar.testing import RequestFactory, create_test_client from litestar.types import StreamType from litestar.utils import AsyncIteratorWrapper -from litestar.utils.signature import ParsedSignature -from tests.models import DataclassPerson, DataclassPersonFactory if TYPE_CHECKING: from typing import AsyncGenerator @@ -72,28 +67,6 @@ def __init__(self) -> None: super().__init__(iterator=MySyncIterator()) -async def test_to_response_async_await(anyio_backend: str) -> None: - @route(http_method=HttpMethod.POST, path="/person") - async def handler(data: DataclassPerson) -> DataclassPerson: - assert isinstance(data, DataclassPerson) - return data - - person_instance = DataclassPersonFactory.build() - handler._signature_model = SignatureModel.create( - dependency_name_set=set(), - fn=handler.fn, - data_dto=None, - parsed_signature=ParsedSignature.from_fn(handler.fn, {}), - type_decoders=[], - ) - - response = await handler.to_response( - data=handler.fn(data=person_instance), - request=RequestFactory(app=Litestar(route_handlers=[handler])).get(route_handler=handler), - ) - assert loads(response.body) == msgspec.to_builtins(person_instance) # type: ignore[attr-defined] - - async def test_to_response_returning_litestar_response() -> None: @get(path="/test") def handler() -> Response: diff --git a/tests/unit/test_response/test_type_decoders.py b/tests/unit/test_response/test_type_decoders.py index 3536481401..3d72716557 100644 --- a/tests/unit/test_response/test_type_decoders.py +++ b/tests/unit/test_response/test_type_decoders.py @@ -101,4 +101,4 @@ def test_resolve_type_decoders( path: str, method: Union[HttpMethod, Literal["websocket"]], type_decoders: TypeDecodersSequence, app: Litestar ) -> None: handler = app.route_handler_method_map[path][method] - assert handler.resolve_type_decoders() == type_decoders + assert handler.type_decoders == handler.resolve_type_decoders() == tuple(type_decoders) diff --git a/tests/unit/test_response/test_type_encoders.py b/tests/unit/test_response/test_type_encoders.py index ec59ee37a0..15ac0eef98 100644 --- a/tests/unit/test_response/test_type_encoders.py +++ b/tests/unit/test_response/test_type_encoders.py @@ -1,6 +1,6 @@ from typing import TYPE_CHECKING, Any, Tuple -from litestar import Controller, HttpMethod, Litestar, Response, Router, get +from litestar import Controller, Litestar, Response, Router, get from litestar.testing import create_test_client if TYPE_CHECKING: @@ -32,8 +32,8 @@ def handler(self) -> Any: ... router = Router("/router", type_encoders={router_type: router_encoder}, route_handlers=[MyController]) app = Litestar([router], type_encoders={app_type: app_encoder}) - route_handler = app.routes[0].route_handler_map[HttpMethod.GET] # type: ignore[union-attr] - encoders = route_handler.resolve_type_encoders() + route_handler = app.routes[0].route_handler_map["GET"] # type: ignore[union-attr] + encoders = route_handler.type_encoders assert encoders.get(handler_type) == handler_encoder assert encoders.get(controller_type) == controller_encoder assert encoders.get(router_type) == router_encoder diff --git a/tests/unit/test_response_class_resolution.py b/tests/unit/test_response_class_resolution.py index 0744a6d302..05ad58cd51 100644 --- a/tests/unit/test_response_class_resolution.py +++ b/tests/unit/test_response_class_resolution.py @@ -36,27 +36,17 @@ def test_response_class_resolution_of_layers( expected: Type[Response], ) -> None: class MyController(Controller): - @get() + response_class = controller_response_class + + @get(response_class=handler_response_class) def handler(self) -> None: pass - if controller_response_class: - MyController.response_class = ControllerResponse - - router = Router(path="/", route_handlers=[MyController]) - - if router_response_class: - router.response_class = router_response_class + router = Router(path="/", route_handlers=[MyController], response_class=router_response_class) - app = Litestar(route_handlers=[router]) - - if app_response_class: - app.response_class = app_response_class + app = Litestar(route_handlers=[router], response_class=app_response_class) route_handler: HTTPRouteHandler = app.route_handler_method_map["/"][HttpMethod.GET] # type: ignore[assignment] - if handler_response_class: - route_handler.response_class = handler_response_class - - response_class = route_handler.resolve_response_class() + response_class = route_handler.response_class assert response_class is expected diff --git a/tests/unit/test_static_files/test_create_static_router.py b/tests/unit/test_static_files/test_create_static_router.py index ffab6383ca..86102cd8cd 100644 --- a/tests/unit/test_static_files/test_create_static_router.py +++ b/tests/unit/test_static_files/test_create_static_router.py @@ -58,7 +58,7 @@ async def before_request(request: Request) -> Any: tags=tags, ) - assert router.guards == [guard] + assert router.guards == (guard,) assert router.exception_handlers == exception_handlers assert router.opt == opts assert router.after_request is after_request diff --git a/tests/unit/test_static_files/test_static_files_validation.py b/tests/unit/test_static_files/test_static_files_validation.py index 57ef6a4b59..bbe7ba90c9 100644 --- a/tests/unit/test_static_files/test_static_files_validation.py +++ b/tests/unit/test_static_files/test_static_files_validation.py @@ -1,4 +1,5 @@ from pathlib import Path +from typing import List import pytest @@ -9,7 +10,8 @@ from litestar.testing import create_test_client -def test_validation_of_directories() -> None: +@pytest.mark.parametrize("directories", [[], [""]]) +def test_validation_of_directories(directories: List[str]) -> None: with pytest.raises(ImproperlyConfiguredException): create_static_files_router(path="/static", directories=[]) diff --git a/tests/unit/test_utils/test_signature.py b/tests/unit/test_utils/test_signature.py index 6895ed79f5..23207d6025 100644 --- a/tests/unit/test_utils/test_signature.py +++ b/tests/unit/test_utils/test_signature.py @@ -11,7 +11,7 @@ import pytest from typing_extensions import Annotated, NotRequired, Required, TypedDict, get_args, get_type_hints -from litestar import Controller, Router, post +from litestar import Controller, post from litestar.exceptions import ImproperlyConfiguredException from litestar.exceptions.base_exceptions import LitestarWarning from litestar.response.base import ASGIResponse @@ -205,8 +205,8 @@ def __class_getitem__(cls, model_class: type) -> type: cls_dict = {"model_class": model_class} return type(f"GenericController[{model_class.__name__}", (cls,), cls_dict) - def __init__(self, owner: Router) -> None: - super().__init__(owner) + def __init__(self) -> None: + super().__init__() self.signature_namespace[T] = self.model_class # type: ignore[misc] @@ -228,8 +228,8 @@ def test_using_generics_in_controller_annotations(annotation_type: type, expecte class ConcreteController(BaseController[annotation_type]): # type: ignore[valid-type] path = "/" - controller_object = ConcreteController(owner=None) # type: ignore[arg-type] + controller_object = ConcreteController() - signature = controller_object.get_route_handlers()[0].parsed_fn_signature + signature = (controller_object.get_route_handlers()[0]).merge(controller_object.as_router()).parsed_fn_signature actual = {"data": signature.parameters["data"].annotation, "return": signature.return_type.annotation} assert actual == expected diff --git a/tests/unit/test_utils/test_sync.py b/tests/unit/test_utils/test_sync.py index 6e3be65c94..a1cfe3d104 100644 --- a/tests/unit/test_utils/test_sync.py +++ b/tests/unit/test_utils/test_sync.py @@ -32,10 +32,10 @@ async def my_method(self, value: int) -> None: wrapped_method = ensure_async_callable(instance.my_method) - await wrapped_method(1) # type: ignore[unused-coroutine] + await wrapped_method(1) assert instance.value == 1 - await wrapped_method(value=10) # type: ignore[unused-coroutine] + await wrapped_method(value=10) assert instance.value == 10 @@ -62,10 +62,10 @@ async def my_function(new_value: int) -> None: wrapped_function = ensure_async_callable(my_function) - await wrapped_function(1) # type: ignore[unused-coroutine] + await wrapped_function(1) assert obj["value"] == 1 - await wrapped_function(new_value=10) # type: ignore[unused-coroutine] + await wrapped_function(new_value=10) assert obj["value"] == 10 @@ -98,8 +98,8 @@ async def __call__(self, new_value: int) -> None: wrapped_class = ensure_async_callable(instance) - await wrapped_class(1) # type: ignore[unused-coroutine] + await wrapped_class(1) assert instance.value == 1 - await wrapped_class(new_value=10) # type: ignore[unused-coroutine] + await wrapped_class(new_value=10) assert instance.value == 10 diff --git a/tests/unit/test_websocket_class_resolution.py b/tests/unit/test_websocket_class_resolution.py index 4ec1967320..887ce7494b 100644 --- a/tests/unit/test_websocket_class_resolution.py +++ b/tests/unit/test_websocket_class_resolution.py @@ -39,40 +39,33 @@ def test_websocket_class_resolution_of_layers( expected: Type[WebSocket], ) -> None: class MyController(Controller): - @websocket_listener("/") + websocket_class = controller_websocket_class + + @websocket_listener("/", websocket_class=handler_websocket_class) def handler(self, data: str) -> None: return - if controller_websocket_class: - MyController.websocket_class = ControllerWebSocket - - router = Router(path="/", route_handlers=[MyController]) - - if router_websocket_class: - router.websocket_class = router_websocket_class + router = Router(path="/", route_handlers=[MyController], websocket_class=router_websocket_class) - app = Litestar(route_handlers=[router]) - - if app_websocket_class or not has_default_app_class: - app.websocket_class = app_websocket_class # type: ignore[assignment] + app = Litestar( + route_handlers=[router], + websocket_class=app_websocket_class if app_websocket_class or not has_default_app_class else None, + ) route_handler = app.routes[0].route_handler # type: ignore[union-attr] - if handler_websocket_class: - route_handler.websocket_class = handler_websocket_class # type: ignore[union-attr] - - websocket_class = route_handler.resolve_websocket_class() # type: ignore[union-attr] + websocket_class = route_handler.websocket_class # type: ignore[union-attr] assert websocket_class is expected @pytest.mark.parametrize( - "handler_websocket_class, router_websocket_class, app_websocket_class, has_default_app_class, expected", + "handler_websocket_class, router_websocket_class, app_websocket_class, expected", ( - (HandlerWebSocket, RouterWebSocket, AppWebSocket, True, HandlerWebSocket), - (None, RouterWebSocket, AppWebSocket, True, RouterWebSocket), - (None, None, AppWebSocket, True, AppWebSocket), - (None, None, None, True, WebSocket), - (None, None, None, False, WebSocket), + (HandlerWebSocket, RouterWebSocket, AppWebSocket, HandlerWebSocket), + (None, RouterWebSocket, AppWebSocket, RouterWebSocket), + (None, None, AppWebSocket, AppWebSocket), + (None, None, None, WebSocket), + (None, None, None, WebSocket), ), ids=( "Custom class for all layers", @@ -86,7 +79,6 @@ def test_listener_websocket_class_resolution_of_layers( handler_websocket_class: Union[Type[WebSocket], None], router_websocket_class: Union[Type[WebSocket], None], app_websocket_class: Union[Type[WebSocket], None], - has_default_app_class: bool, expected: Type[WebSocket], ) -> None: class Handler(WebsocketListener): @@ -96,20 +88,10 @@ class Handler(WebsocketListener): def on_receive(self, data: str) -> str: # pyright: ignore return data - router = Router(path="/", route_handlers=[Handler]) - - if router_websocket_class: - router.websocket_class = router_websocket_class - - app = Litestar(route_handlers=[router]) - - if app_websocket_class or not has_default_app_class: - app.websocket_class = app_websocket_class # type: ignore[assignment] + router = Router(path="/", route_handlers=[Handler], websocket_class=router_websocket_class) + app = Litestar(route_handlers=[router], websocket_class=app_websocket_class) route_handler = app.routes[0].route_handler # type: ignore[union-attr] - if handler_websocket_class: - route_handler.websocket_class = handler_websocket_class # type: ignore[union-attr] - - websocket_class = route_handler.resolve_websocket_class() # type: ignore[union-attr] + websocket_class = route_handler.websocket_class # type: ignore[union-attr] assert websocket_class is expected