|  | 
|  | 1 | +from asyncio import Event, create_task, gather, sleep, wait_for | 
|  | 2 | +from typing import Callable | 
|  | 3 | + | 
|  | 4 | +import pytest | 
|  | 5 | + | 
|  | 6 | +from graphql.pyutils import gather_with_cancel, is_awaitable | 
|  | 7 | + | 
|  | 8 | + | 
|  | 9 | +class Controller: | 
|  | 10 | +    def reset(self, wait=False): | 
|  | 11 | +        self.event = Event() | 
|  | 12 | +        if not wait: | 
|  | 13 | +            self.event.set() | 
|  | 14 | +        self.returned = [] | 
|  | 15 | + | 
|  | 16 | + | 
|  | 17 | +controller = Controller() | 
|  | 18 | + | 
|  | 19 | + | 
|  | 20 | +async def coroutine(value: int) -> int: | 
|  | 21 | +    """Simple coroutine that returns a value.""" | 
|  | 22 | +    if value > 2: | 
|  | 23 | +        raise RuntimeError("Oops") | 
|  | 24 | +    await controller.event.wait() | 
|  | 25 | +    controller.returned.append(value) | 
|  | 26 | +    return value | 
|  | 27 | + | 
|  | 28 | + | 
|  | 29 | +class CustomAwaitable: | 
|  | 30 | +    """Custom awaitable that return a value.""" | 
|  | 31 | + | 
|  | 32 | +    def __init__(self, value: int): | 
|  | 33 | +        self.value = value | 
|  | 34 | +        self.coroutine = coroutine(value) | 
|  | 35 | + | 
|  | 36 | +    def __await__(self): | 
|  | 37 | +        return self.coroutine.__await__() | 
|  | 38 | + | 
|  | 39 | + | 
|  | 40 | +awaitable_factories: dict[str, Callable] = { | 
|  | 41 | +    "coroutine": coroutine, | 
|  | 42 | +    "task": lambda value: create_task(coroutine(value)), | 
|  | 43 | +    "custom": lambda value: CustomAwaitable(value), | 
|  | 44 | +} | 
|  | 45 | + | 
|  | 46 | +with_all_types_of_awaitables = pytest.mark.parametrize( | 
|  | 47 | +    "type_of_awaitable", awaitable_factories | 
|  | 48 | +) | 
|  | 49 | + | 
|  | 50 | + | 
|  | 51 | +def describe_gather_with_cancel(): | 
|  | 52 | +    @with_all_types_of_awaitables | 
|  | 53 | +    @pytest.mark.asyncio | 
|  | 54 | +    async def gathers_all_values(type_of_awaitable: str): | 
|  | 55 | +        factory = awaitable_factories[type_of_awaitable] | 
|  | 56 | +        values = list(range(3)) | 
|  | 57 | + | 
|  | 58 | +        controller.reset() | 
|  | 59 | +        aws = [factory(i) for i in values] | 
|  | 60 | + | 
|  | 61 | +        assert await gather(*aws) == values | 
|  | 62 | +        assert controller.returned == values | 
|  | 63 | + | 
|  | 64 | +        controller.reset() | 
|  | 65 | +        aws = [factory(i) for i in values] | 
|  | 66 | + | 
|  | 67 | +        result = gather_with_cancel(*aws) | 
|  | 68 | +        assert is_awaitable(result) | 
|  | 69 | + | 
|  | 70 | +        awaited = await wait_for(result, 1) | 
|  | 71 | +        assert awaited == values | 
|  | 72 | + | 
|  | 73 | +    @with_all_types_of_awaitables | 
|  | 74 | +    @pytest.mark.asyncio | 
|  | 75 | +    async def raises_on_exception(type_of_awaitable: str): | 
|  | 76 | +        factory = awaitable_factories[type_of_awaitable] | 
|  | 77 | +        values = list(range(4)) | 
|  | 78 | + | 
|  | 79 | +        controller.reset() | 
|  | 80 | +        aws = [factory(i) for i in values] | 
|  | 81 | + | 
|  | 82 | +        with pytest.raises(RuntimeError, match="Oops"): | 
|  | 83 | +            await gather(*aws) | 
|  | 84 | +        assert controller.returned == values[:-1] | 
|  | 85 | + | 
|  | 86 | +        controller.reset() | 
|  | 87 | +        aws = [factory(i) for i in values] | 
|  | 88 | + | 
|  | 89 | +        result = gather_with_cancel(*aws) | 
|  | 90 | +        assert is_awaitable(result) | 
|  | 91 | + | 
|  | 92 | +        with pytest.raises(RuntimeError, match="Oops"): | 
|  | 93 | +            await wait_for(result, 1) | 
|  | 94 | +        assert controller.returned == values[:-1] | 
|  | 95 | + | 
|  | 96 | +    @with_all_types_of_awaitables | 
|  | 97 | +    @pytest.mark.asyncio | 
|  | 98 | +    async def cancels_on_exception(type_of_awaitable: str): | 
|  | 99 | +        factory = awaitable_factories[type_of_awaitable] | 
|  | 100 | +        values = list(range(4)) | 
|  | 101 | + | 
|  | 102 | +        controller.reset(wait=True) | 
|  | 103 | +        aws = [factory(i) for i in values] | 
|  | 104 | + | 
|  | 105 | +        with pytest.raises(RuntimeError, match="Oops"): | 
|  | 106 | +            await gather(*aws) | 
|  | 107 | +        assert not controller.returned | 
|  | 108 | + | 
|  | 109 | +        # check that the standard gather continues to produce results | 
|  | 110 | +        controller.event.set() | 
|  | 111 | +        await sleep(0) | 
|  | 112 | +        assert controller.returned == values[:-1] | 
|  | 113 | + | 
|  | 114 | +        controller.reset(wait=True) | 
|  | 115 | +        aws = [factory(i) for i in values] | 
|  | 116 | + | 
|  | 117 | +        result = gather_with_cancel(*aws) | 
|  | 118 | +        assert is_awaitable(result) | 
|  | 119 | + | 
|  | 120 | +        with pytest.raises(RuntimeError, match="Oops"): | 
|  | 121 | +            await wait_for(result, 1) | 
|  | 122 | +        assert not controller.returned | 
|  | 123 | + | 
|  | 124 | +        # check that gather_with_cancel stops producing results | 
|  | 125 | +        controller.event.set() | 
|  | 126 | +        await sleep(0) | 
|  | 127 | +        if type_of_awaitable == "custom": | 
|  | 128 | +            # Cancellation of custom awaitables is not supported | 
|  | 129 | +            assert controller.returned == values[:-1] | 
|  | 130 | +        else: | 
|  | 131 | +            assert not controller.returned | 
0 commit comments