|  | 
|  | 1 | +from __future__ import annotations | 
|  | 2 | + | 
|  | 3 | +import asyncio | 
|  | 4 | + | 
|  | 5 | +import pytest | 
|  | 6 | + | 
|  | 7 | +from graphql_server.channels.handlers.base import ChannelsConsumer | 
|  | 8 | + | 
|  | 9 | + | 
|  | 10 | +class DummyChannelLayer: | 
|  | 11 | +    def __init__(self) -> None: | 
|  | 12 | +        self.added: list[tuple[str, str]] = [] | 
|  | 13 | +        self.discarded: list[tuple[str, str]] = [] | 
|  | 14 | + | 
|  | 15 | +    async def group_add(self, group: str, channel: str) -> None: | 
|  | 16 | +        self.added.append((group, channel)) | 
|  | 17 | + | 
|  | 18 | +    async def group_discard(self, group: str, channel: str) -> None: | 
|  | 19 | +        self.discarded.append((group, channel)) | 
|  | 20 | + | 
|  | 21 | + | 
|  | 22 | +@pytest.mark.asyncio | 
|  | 23 | +async def test_channel_listen_receives_messages_and_cleans_up() -> None: | 
|  | 24 | +    consumer = ChannelsConsumer() | 
|  | 25 | +    layer = DummyChannelLayer() | 
|  | 26 | +    consumer.channel_layer = layer | 
|  | 27 | +    consumer.channel_name = "chan" | 
|  | 28 | + | 
|  | 29 | +    gen = consumer.channel_listen("test.message", groups=["g"], timeout=0.1) | 
|  | 30 | + | 
|  | 31 | +    async def send() -> None: | 
|  | 32 | +        await asyncio.sleep(0) | 
|  | 33 | +        queue = next(iter(consumer.listen_queues["test.message"])) | 
|  | 34 | +        queue.put_nowait({"type": "test.message", "payload": 1}) | 
|  | 35 | + | 
|  | 36 | +    asyncio.create_task(send()) | 
|  | 37 | + | 
|  | 38 | +    with pytest.deprecated_call(match="Use listen_to_channel instead"): | 
|  | 39 | +        message = await gen.__anext__() | 
|  | 40 | +    assert message == {"type": "test.message", "payload": 1} | 
|  | 41 | + | 
|  | 42 | +    await gen.aclose() | 
|  | 43 | + | 
|  | 44 | +    assert layer.added == [("g", "chan")] | 
|  | 45 | +    assert layer.discarded == [("g", "chan")] | 
|  | 46 | + | 
|  | 47 | + | 
|  | 48 | +@pytest.mark.asyncio | 
|  | 49 | +async def test_channel_listen_times_out() -> None: | 
|  | 50 | +    consumer = ChannelsConsumer() | 
|  | 51 | +    layer = DummyChannelLayer() | 
|  | 52 | +    consumer.channel_layer = layer | 
|  | 53 | +    consumer.channel_name = "chan" | 
|  | 54 | + | 
|  | 55 | +    gen = consumer.channel_listen("test.message", groups=["g"], timeout=0.01) | 
|  | 56 | + | 
|  | 57 | +    with pytest.deprecated_call(match="Use listen_to_channel instead"): | 
|  | 58 | +        with pytest.raises(StopAsyncIteration): | 
|  | 59 | +            await gen.__anext__() | 
|  | 60 | + | 
|  | 61 | +    assert layer.added == [("g", "chan")] | 
|  | 62 | +    assert layer.discarded == [("g", "chan")] | 
0 commit comments