Skip to content

Commit 1e77b8c

Browse files
committed
Use a more tolerant aclosing() context manager
1 parent e6bb12d commit 1e77b8c

File tree

6 files changed

+241
-71
lines changed

6 files changed

+241
-71
lines changed

src/graphql/execution/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
FormattedIncrementalResult,
3131
Middleware,
3232
)
33-
from .iterators import map_async_iterable
33+
from .async_iterables import flatten_async_iterable, map_async_iterable
3434
from .middleware import MiddlewareManager
3535
from .values import get_argument_values, get_directive_values, get_variable_values
3636

@@ -58,6 +58,7 @@
5858
"FormattedIncrementalDeferResult",
5959
"FormattedIncrementalStreamResult",
6060
"FormattedIncrementalResult",
61+
"flatten_async_iterable",
6162
"map_async_iterable",
6263
"Middleware",
6364
"MiddlewareManager",
Lines changed: 27 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations # Python < 3.10
22

3+
from contextlib import AbstractAsyncContextManager
34
from typing import (
45
Any,
56
AsyncGenerator,
@@ -11,25 +12,34 @@
1112
)
1213

1314

14-
try:
15-
from contextlib import aclosing
16-
except ImportError: # python < 3.10
17-
from contextlib import asynccontextmanager
18-
19-
@asynccontextmanager # type: ignore
20-
async def aclosing(thing):
21-
try:
22-
yield thing
23-
finally:
24-
await thing.aclose()
25-
15+
__all__ = ["aclosing", "flatten_async_iterable", "map_async_iterable"]
2616

2717
T = TypeVar("T")
2818
V = TypeVar("V")
2919

3020
AsyncIterableOrGenerator = Union[AsyncGenerator[T, None], AsyncIterable[T]]
3121

32-
__all__ = ["flatten_async_iterable", "map_async_iterable"]
22+
23+
class aclosing(AbstractAsyncContextManager):
24+
"""Async context manager for safely finalizing an async iterator or generator.
25+
26+
Contrary to the function available via the standard library, this one silently
27+
ignores the case that custom iterators have no aclose() method.
28+
"""
29+
30+
def __init__(self, iterable: AsyncIterableOrGenerator[T]) -> None:
31+
self.iterable = iterable
32+
33+
async def __aenter__(self) -> AsyncIterableOrGenerator[T]:
34+
return self.iterable
35+
36+
async def __aexit__(self, *_exc_info: Any) -> None:
37+
try:
38+
aclose = self.iterable.aclose # type: ignore
39+
except AttributeError:
40+
pass # do not complain if the iterator has no aclose() method
41+
else:
42+
await aclose()
3343

3444

3545
async def flatten_async_iterable(
@@ -48,7 +58,7 @@ async def flatten_async_iterable(
4858

4959

5060
async def map_async_iterable(
51-
iterable: AsyncIterable[T], callback: Callable[[T], Awaitable[V]]
61+
iterable: AsyncIterableOrGenerator[T], callback: Callable[[T], Awaitable[V]]
5262
) -> AsyncGenerator[V, None]:
5363
"""Map an AsyncIterable over a callback function.
5464
@@ -58,10 +68,6 @@ async def map_async_iterable(
5868
the generator finishes or closes.
5969
"""
6070

61-
aiter = iterable.__aiter__()
62-
try:
63-
async for element in aiter:
64-
yield await callback(element)
65-
finally:
66-
if hasattr(aiter, "aclose"):
67-
await aiter.aclose()
71+
async with aclosing(iterable) as items: # type: ignore
72+
async for item in items:
73+
yield await callback(item)

src/graphql/execution/execute.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,8 @@
7070
is_non_null_type,
7171
is_object_type,
7272
)
73+
from .async_iterables import flatten_async_iterable, map_async_iterable
7374
from .collect_fields import FieldsAndPatches, collect_fields, collect_subfields
74-
from .iterators import flatten_async_iterable, map_async_iterable
7575
from .middleware import MiddlewareManager
7676
from .values import get_argument_values, get_directive_values, get_variable_values
7777

tests/execution/test_flatten_async_iterable.py

Lines changed: 83 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from pytest import mark, raises
44

5-
from graphql.execution.iterators import flatten_async_iterable
5+
from graphql.execution import flatten_async_iterable
66

77

88
try: # pragma: no cover
@@ -129,3 +129,85 @@ async def nested2() -> AsyncGenerator[float, None]:
129129
assert await anext(doubles) == 2.2
130130
with raises(StopAsyncIteration):
131131
assert await anext(doubles)
132+
133+
@mark.asyncio
134+
async def closes_nested_async_iterators():
135+
closed = []
136+
137+
class Source:
138+
def __init__(self):
139+
self.counter = 0
140+
141+
def __aiter__(self):
142+
return self
143+
144+
async def __anext__(self):
145+
if self.counter == 2:
146+
raise StopAsyncIteration
147+
self.counter += 1
148+
return Nested(self.counter)
149+
150+
async def aclose(self):
151+
nonlocal closed
152+
closed.append(self.counter)
153+
154+
class Nested:
155+
def __init__(self, value):
156+
self.value = value
157+
self.counter = 0
158+
159+
def __aiter__(self):
160+
return self
161+
162+
async def __anext__(self):
163+
if self.counter == 2:
164+
raise StopAsyncIteration
165+
self.counter += 1
166+
return self.value + self.counter / 10
167+
168+
async def aclose(self):
169+
nonlocal closed
170+
closed.append(self.value + self.counter / 10)
171+
172+
doubles = flatten_async_iterable(Source())
173+
174+
result = [x async for x in doubles]
175+
176+
assert result == [1.1, 1.2, 2.1, 2.2]
177+
178+
assert closed == [1.2, 2.2, 2]
179+
180+
@mark.asyncio
181+
async def works_with_nested_async_iterators_that_have_no_close_method():
182+
class Source:
183+
def __init__(self):
184+
self.counter = 0
185+
186+
def __aiter__(self):
187+
return self
188+
189+
async def __anext__(self):
190+
if self.counter == 2:
191+
raise StopAsyncIteration
192+
self.counter += 1
193+
return Nested(self.counter)
194+
195+
class Nested:
196+
def __init__(self, value):
197+
self.value = value
198+
self.counter = 0
199+
200+
def __aiter__(self):
201+
return self
202+
203+
async def __anext__(self):
204+
if self.counter == 2:
205+
raise StopAsyncIteration
206+
self.counter += 1
207+
return self.value + self.counter / 10
208+
209+
doubles = flatten_async_iterable(Source())
210+
211+
result = [x async for x in doubles]
212+
213+
assert result == [1.1, 1.2, 2.1, 2.2]

tests/execution/test_map_async_iterable.py

Lines changed: 47 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -3,18 +3,22 @@
33
from graphql.execution import map_async_iterable
44

55

6-
async def map_doubles(x):
6+
try: # pragma: no cover
7+
anext
8+
except NameError: # pragma: no cover (Python < 3.10)
9+
# noinspection PyShadowingBuiltins
10+
async def anext(iterator):
11+
"""Return the next item from an async iterator."""
12+
return await iterator.__anext__()
13+
14+
15+
async def map_doubles(x: int) -> int:
716
return x + x
817

918

1019
def describe_map_async_iterable():
1120
@mark.asyncio
12-
async def test_inner_close_called():
13-
"""
14-
Test that a custom iterator with aclose() gets an aclose() call
15-
when outer is closed
16-
"""
17-
21+
async def inner_is_closed_when_outer_is_closed():
1822
class Inner:
1923
def __init__(self):
2024
self.closed = False
@@ -30,19 +34,14 @@ async def __anext__(self):
3034

3135
inner = Inner()
3236
outer = map_async_iterable(inner, map_doubles)
33-
it = outer.__aiter__()
34-
assert await it.__anext__() == 2
37+
iterator = outer.__aiter__()
38+
assert await anext(iterator) == 2
3539
assert not inner.closed
3640
await outer.aclose()
3741
assert inner.closed
3842

3943
@mark.asyncio
40-
async def test_inner_close_called_on_callback_err():
41-
"""
42-
Test that a custom iterator with aclose() gets an aclose() call
43-
when the callback errors and the outer iterator aborts.
44-
"""
45-
44+
async def inner_is_closed_on_callback_error():
4645
class Inner:
4746
def __init__(self):
4847
self.closed = False
@@ -62,17 +61,11 @@ async def callback(v):
6261
inner = Inner()
6362
outer = map_async_iterable(inner, callback)
6463
with raises(RuntimeError):
65-
async for _ in outer:
66-
pass
64+
await anext(outer)
6765
assert inner.closed
6866

6967
@mark.asyncio
70-
async def test_inner_exit_on_callback_err():
71-
"""
72-
Test that a custom iterator with aclose() gets an aclose() call
73-
when the callback errors and the outer iterator aborts.
74-
"""
75-
68+
async def test_inner_exits_on_callback_error():
7669
inner_exit = False
7770

7871
async def inner():
@@ -88,6 +81,35 @@ async def callback(v):
8881

8982
outer = map_async_iterable(inner(), callback)
9083
with raises(RuntimeError):
91-
async for _ in outer:
92-
pass
84+
await anext(outer)
9385
assert inner_exit
86+
87+
@mark.asyncio
88+
async def inner_has_no_close_method_when_outer_is_closed():
89+
class Inner:
90+
def __aiter__(self):
91+
return self
92+
93+
async def __anext__(self):
94+
return 1
95+
96+
outer = map_async_iterable(Inner(), map_doubles)
97+
iterator = outer.__aiter__()
98+
assert await anext(iterator) == 2
99+
await outer.aclose()
100+
101+
@mark.asyncio
102+
async def inner_has_no_close_method_on_callback_error():
103+
class Inner:
104+
def __aiter__(self):
105+
return self
106+
107+
async def __anext__(self):
108+
return 1
109+
110+
async def callback(v):
111+
raise RuntimeError()
112+
113+
outer = map_async_iterable(Inner(), callback)
114+
with raises(RuntimeError):
115+
await anext(outer)

0 commit comments

Comments
 (0)