From 43e3041193d74fe9b2a1c2d1e327729f67468f24 Mon Sep 17 00:00:00 2001 From: kedod <35638715+kedod@users.noreply.github.com> Date: Sat, 6 Apr 2024 23:48:26 +0200 Subject: [PATCH] feat: Add async `websocket_connect` to `AsyncTestClient` (#3328) feat: Add `websocket_connect` method to AsyncTestClient Co-authored-by: kedod --- litestar/testing/client/async_client.py | 59 ++++++++++++++++++++- tests/unit/test_testing/test_test_client.py | 51 +++++++++++++++++- 2 files changed, 107 insertions(+), 3 deletions(-) diff --git a/litestar/testing/client/async_client.py b/litestar/testing/client/async_client.py index cf66f12f47..0e4d779170 100644 --- a/litestar/testing/client/async_client.py +++ b/litestar/testing/client/async_client.py @@ -1,14 +1,15 @@ from __future__ import annotations from contextlib import AsyncExitStack -from typing import TYPE_CHECKING, Any, Generic, Mapping, TypeVar +from typing import TYPE_CHECKING, Any, Generic, Mapping, Sequence, TypeVar +from urllib.parse import urljoin from httpx import USE_CLIENT_DEFAULT, AsyncClient, Response from litestar import HttpMethod from litestar.testing.client.base import BaseTestClient from litestar.testing.life_span_handler import LifeSpanHandler -from litestar.testing.transport import TestClientTransport +from litestar.testing.transport import ConnectionUpgradeExceptionError, TestClientTransport from litestar.types import AnyIOBackend, ASGIApp if TYPE_CHECKING: @@ -27,6 +28,7 @@ from typing_extensions import Self from litestar.middleware.session.base import BaseBackendConfig + from litestar.testing.websocket_test_session import WebSocketTestSession T = TypeVar("T", bound=ASGIApp) @@ -468,6 +470,59 @@ async def delete( extensions=None if extensions is None else dict(extensions), ) + async def websocket_connect( + self, + url: str, + subprotocols: Sequence[str] | None = None, + params: QueryParamTypes | None = None, + headers: HeaderTypes | None = None, + cookies: CookieTypes | None = None, + auth: AuthTypes | UseClientDefault = USE_CLIENT_DEFAULT, + follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT, + timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT, + extensions: Mapping[str, Any] | None = None, + ) -> WebSocketTestSession: + """Sends a GET request to establish a websocket connection. + + Args: + url: Request URL. + subprotocols: Websocket subprotocols. + params: Query parameters. + headers: Request headers. + cookies: Request cookies. + auth: Auth headers. + follow_redirects: Whether to follow redirects. + timeout: Request timeout. + extensions: Dictionary of ASGI extensions. + + Returns: + A `WebSocketTestSession ` instance. + """ + url = urljoin("ws://testserver", url) + default_headers: dict[str, str] = {} + default_headers.setdefault("connection", "upgrade") + default_headers.setdefault("sec-websocket-key", "testserver==") + default_headers.setdefault("sec-websocket-version", "13") + if subprotocols is not None: + default_headers.setdefault("sec-websocket-protocol", ", ".join(subprotocols)) + try: + await AsyncClient.request( + self, + "GET", + url, + headers={**dict(headers or {}), **default_headers}, # type: ignore[misc] + params=params, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=None if extensions is None else dict(extensions), + ) + except ConnectionUpgradeExceptionError as exc: + return exc.session + + raise RuntimeError("Expected WebSocket upgrade") # pragma: no cover + async def get_session_data(self) -> dict[str, Any]: """Get session data. diff --git a/tests/unit/test_testing/test_test_client.py b/tests/unit/test_testing/test_test_client.py index a555be4ddd..31840e2220 100644 --- a/tests/unit/test_testing/test_test_client.py +++ b/tests/unit/test_testing/test_test_client.py @@ -5,7 +5,7 @@ from litestar import Controller, WebSocket, delete, head, patch, put, websocket from litestar.status_codes import HTTP_200_OK, HTTP_201_CREATED, HTTP_204_NO_CONTENT -from litestar.testing import AsyncTestClient, WebSocketTestSession, create_test_client +from litestar.testing import AsyncTestClient, WebSocketTestSession, create_async_test_client, create_test_client if TYPE_CHECKING: from litestar.middleware.session.base import BaseBackendConfig @@ -261,3 +261,52 @@ async def handler(socket: WebSocket) -> None: Empty ), client.websocket_connect("/"): pass + + +@pytest.mark.parametrize("block,timeout", [(False, None), (False, 0.001), (True, 0.001)]) +@pytest.mark.parametrize( + "receive_method", + [ + WebSocketTestSession.receive, + WebSocketTestSession.receive_json, + WebSocketTestSession.receive_text, + WebSocketTestSession.receive_bytes, + ], +) +async def test_websocket_test_session_block_timeout_async( + receive_method: Callable[..., Any], block: bool, timeout: Optional[float], anyio_backend: "AnyIOBackend" +) -> None: + @websocket() + async def handler(socket: WebSocket) -> None: + await socket.accept() + + with pytest.raises(Empty): + async with create_async_test_client(handler, backend=anyio_backend) as client: + with await client.websocket_connect("/") as ws: + receive_method(ws, timeout=timeout, block=block) + + +async def test_websocket_accept_timeout_async(anyio_backend: "AnyIOBackend") -> None: + @websocket() + async def handler(socket: WebSocket) -> None: + pass + + async with create_async_test_client(handler, backend=anyio_backend, timeout=0.1) as client: + with pytest.raises(Empty): + with await client.websocket_connect("/"): + pass + + +async def test_websocket_connect_async(anyio_backend: "AnyIOBackend") -> None: + @websocket() + async def handler(socket: WebSocket) -> None: + await socket.accept() + data = await socket.receive_json() + await socket.send_json(data) + await socket.close() + + async with create_async_test_client(handler, backend=anyio_backend, timeout=0.1) as client: + with await client.websocket_connect("/", subprotocols="wamp") as ws: + ws.send_json({"data": "123"}) + data = ws.receive_json() + assert data == {"data": "123"}