Skip to content

Commit ba64180

Browse files
committed
add cancellation support to async iterable iteration
GraphQL-core has no PromiseCanceller class: cancellableIterable becomes cancellable_iterable, wrapping each __anext__ in the existing with_abort_signal helper, used in complete_list_value and map_source_to_response. No mapAsyncIterable onDone callback is needed since with_abort_signal cancels its own per-step wait. Replicates graphql/graphql-js@46857e2
1 parent 302b3a6 commit ba64180

2 files changed

Lines changed: 226 additions & 7 deletions

File tree

src/graphql/execution/execute.py

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -964,6 +964,34 @@ async def with_abort_signal(self, awaitable: Awaitable[T]) -> T:
964964
msg = f"Unexpected error value: {inspect(reason)}"
965965
raise TypeError(msg)
966966

967+
def cancellable_iterable(self, iterable: AsyncIterable[T]) -> AsyncIterable[T]:
968+
"""Wrap an async iterable so pending iteration is cancelled on abort.
969+
970+
When the abort signal is triggered, any pending ``__anext__`` call returns
971+
immediately by raising the abort reason. This mirrors the JavaScript
972+
``PromiseCanceller.cancellableIterable``; GraphQL-Core needs no
973+
``PromiseCanceller`` class since :meth:`with_abort_signal` already provides
974+
the cancellation mechanism.
975+
"""
976+
if self.abort_signal is None:
977+
return iterable
978+
with_abort_signal = self.with_abort_signal
979+
iterator = iterable.__aiter__()
980+
981+
class CancellableAsyncIterator:
982+
def __aiter__(self) -> AsyncIterator[T]:
983+
return self
984+
985+
def __anext__(self) -> Awaitable[T]:
986+
return with_abort_signal(iterator.__anext__())
987+
988+
async def aclose(self) -> None:
989+
aclose = getattr(iterator, "aclose", None)
990+
if aclose is not None:
991+
await aclose()
992+
993+
return CancellableAsyncIterator()
994+
967995
async def complete_awaitable_value(
968996
self,
969997
return_type: GraphQLOutputType,
@@ -1189,7 +1217,7 @@ def complete_list_value(
11891217
item_type = return_type.of_type
11901218

11911219
if self.is_async_iterable(result):
1192-
async_iterator = result.__aiter__()
1220+
async_iterator = self.cancellable_iterable(result).__aiter__()
11931221

11941222
return self.complete_async_iterator_value(
11951223
item_type,
@@ -1724,7 +1752,7 @@ async def callback(payload: Any) -> ExecutionResult:
17241752
else cast("ExecutionResult", result)
17251753
)
17261754

1727-
return map_async_iterable(result_or_stream, callback)
1755+
return map_async_iterable(self.cancellable_iterable(result_or_stream), callback)
17281756

17291757
def collect_execution_groups(
17301758
self,

tests/execution/test_abort_signal.py

Lines changed: 196 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from __future__ import annotations
44

55
from asyncio import Event, Future, ensure_future, sleep
6-
from collections.abc import Awaitable
6+
from collections.abc import AsyncIterator, Awaitable
77

88
import pytest
99

@@ -415,6 +415,102 @@ def todo(_info):
415415
],
416416
)
417417

418+
async def stops_the_execution_when_aborted_despite_a_hanging_async_item():
419+
abort_controller = AbortController()
420+
document = parse(
421+
"""
422+
query {
423+
todo {
424+
id
425+
items
426+
}
427+
}
428+
"""
429+
)
430+
431+
async def items(_info):
432+
# never reached: the iterator is cancelled before its body runs
433+
yield await Future() # will never resolve # pragma: no cover
434+
435+
def todo(_info):
436+
return {"id": "1", "items": items}
437+
438+
awaitable_result = execute(
439+
schema,
440+
document,
441+
abort_signal=abort_controller.signal,
442+
root_value={"todo": todo},
443+
)
444+
assert isinstance(awaitable_result, Awaitable)
445+
446+
abort_controller.abort()
447+
448+
result = await awaitable_result
449+
450+
assert result.errors is not None
451+
assert isinstance(result.errors[0].original_error, AbortError)
452+
assert result == (
453+
{"todo": {"id": "1", "items": None}},
454+
[
455+
{
456+
"message": "This operation was aborted",
457+
"locations": [(5, 11)],
458+
"path": ["todo", "items"],
459+
}
460+
],
461+
)
462+
463+
async def stops_the_execution_when_aborted_despite_a_hanging_iterator_no_close():
464+
# Like the test above, but the async iterator has no aclose() method, so
465+
# the cancellable wrapper has nothing to forward the close to.
466+
abort_controller = AbortController()
467+
document = parse(
468+
"""
469+
query {
470+
todo {
471+
id
472+
items
473+
}
474+
}
475+
"""
476+
)
477+
478+
class Items:
479+
def __aiter__(self):
480+
return self
481+
482+
async def __anext__(self):
483+
# never reached: the iterator is cancelled before its body runs
484+
return await Future() # will never resolve # pragma: no cover
485+
486+
def todo(_info):
487+
return {"id": "1", "items": Items()}
488+
489+
awaitable_result = execute(
490+
schema,
491+
document,
492+
abort_signal=abort_controller.signal,
493+
root_value={"todo": todo},
494+
)
495+
assert isinstance(awaitable_result, Awaitable)
496+
497+
abort_controller.abort()
498+
499+
result = await awaitable_result
500+
501+
assert result.errors is not None
502+
assert isinstance(result.errors[0].original_error, AbortError)
503+
assert result == (
504+
{"todo": {"id": "1", "items": None}},
505+
[
506+
{
507+
"message": "This operation was aborted",
508+
"locations": [(5, 11)],
509+
"path": ["todo", "items"],
510+
}
511+
],
512+
)
513+
418514
async def stops_the_execution_when_aborted_with_proper_null_bubbling():
419515
abort_controller = AbortController()
420516
document = parse(
@@ -532,7 +628,7 @@ async def stops_the_execution_when_aborted_pre_execute():
532628

533629
assert result == (None, [{"message": "This operation was aborted"}])
534630

535-
async def stops_the_execution_when_aborted_during_subscription():
631+
async def stops_the_execution_when_aborted_prior_to_return_of_subscription():
536632
abort_controller = AbortController()
537633
document = parse(
538634
"""
@@ -545,17 +641,17 @@ async def stops_the_execution_when_aborted_during_subscription():
545641
def foo(_info):
546642
return Future() # will never resolve
547643

548-
awaitable_result = subscribe(
644+
subscription_promise = subscribe(
549645
schema,
550646
document,
551647
abort_signal=abort_controller.signal,
552648
root_value={"foo": foo},
553649
)
554-
assert isinstance(awaitable_result, Awaitable)
650+
assert isinstance(subscription_promise, Awaitable)
555651

556652
abort_controller.abort()
557653

558-
result = await awaitable_result
654+
result = await subscription_promise
559655

560656
assert result == (
561657
None,
@@ -567,3 +663,98 @@ def foo(_info):
567663
}
568664
],
569665
)
666+
667+
async def successfully_wraps_the_subscription():
668+
abort_controller = AbortController()
669+
document = parse(
670+
"""
671+
subscription {
672+
foo
673+
}
674+
"""
675+
)
676+
677+
async def foo():
678+
yield {"foo": "foo"}
679+
680+
async def resolve_foo(_info):
681+
return foo()
682+
683+
subscription_promise = subscribe(
684+
schema,
685+
document,
686+
abort_signal=abort_controller.signal,
687+
root_value={"foo": resolve_foo},
688+
)
689+
assert isinstance(subscription_promise, Awaitable)
690+
subscription = await subscription_promise
691+
692+
assert isinstance(subscription, AsyncIterator)
693+
694+
assert await anext(subscription) == ({"foo": "foo"}, None)
695+
696+
with pytest.raises(StopAsyncIteration):
697+
await anext(subscription)
698+
699+
async def stops_the_execution_when_aborted_during_subscription():
700+
abort_controller = AbortController()
701+
document = parse(
702+
"""
703+
subscription {
704+
foo
705+
}
706+
"""
707+
)
708+
709+
async def foo():
710+
yield {"foo": "foo"}
711+
712+
subscription = subscribe(
713+
schema,
714+
document,
715+
abort_signal=abort_controller.signal,
716+
root_value={"foo": foo()},
717+
)
718+
719+
assert isinstance(subscription, AsyncIterator)
720+
721+
assert await anext(subscription) == ({"foo": "foo"}, None)
722+
723+
abort_controller.abort()
724+
725+
with pytest.raises(AbortError, match="This operation was aborted"):
726+
await anext(subscription)
727+
728+
async def stops_the_execution_when_aborted_during_async_returned_subscription():
729+
abort_controller = AbortController()
730+
document = parse(
731+
"""
732+
subscription {
733+
foo
734+
}
735+
"""
736+
)
737+
738+
async def foo():
739+
yield {"foo": "foo"}
740+
741+
async def resolve_foo(_info):
742+
return foo()
743+
744+
subscription_promise = subscribe(
745+
schema,
746+
document,
747+
abort_signal=abort_controller.signal,
748+
root_value={"foo": resolve_foo},
749+
)
750+
assert isinstance(subscription_promise, Awaitable)
751+
subscription = await subscription_promise
752+
753+
assert isinstance(subscription, AsyncIterator)
754+
755+
assert await anext(subscription) == ({"foo": "foo"}, None)
756+
757+
abort_controller.abort()
758+
759+
with pytest.raises(AbortError, match="This operation was aborted"):
760+
await anext(subscription)

0 commit comments

Comments
 (0)