diff --git a/src/graphql/execution/execute.py b/src/graphql/execution/execute.py index 8870ff33..9afd6b75 100644 --- a/src/graphql/execution/execute.py +++ b/src/graphql/execution/execute.py @@ -1177,11 +1177,10 @@ async def complete_async_iterator_value( """ errors = async_payload_record.errors if async_payload_record else self.errors stream = self.get_stream_values(field_nodes, path) - is_awaitable = self.is_awaitable + complete_list_item_value = self.complete_list_item_value awaitable_indices: List[int] = [] append_awaitable = awaitable_indices.append completed_results: List[Any] = [] - append_result = completed_results.append index = 0 while True: if ( @@ -1213,46 +1212,23 @@ async def complete_async_iterator_value( value = await anext(iterator) except StopAsyncIteration: break - try: - completed_item = self.complete_value( - item_type, - field_nodes, - info, - item_path, - value, - async_payload_record, - ) - if is_awaitable(completed_item): - # noinspection PyShadowingNames - async def catch_error( - completed_item: Awaitable[Any], item_path: Path - ) -> Any: - try: - return await completed_item - except Exception as raw_error: - error = located_error( - raw_error, field_nodes, item_path.as_list() - ) - self.filter_subsequent_payloads( - item_path, async_payload_record - ) - handle_field_error(error, item_type, errors) - return None - - append_result(catch_error(completed_item, item_path)) - append_awaitable(index) - else: - append_result(completed_item) - except Exception as raw_error: - append_result(None) - error = located_error(raw_error, field_nodes, item_path.as_list()) - self.filter_subsequent_payloads(item_path, async_payload_record) - handle_field_error(error, item_type, errors) except Exception as raw_error: - append_result(None) error = located_error(raw_error, field_nodes, item_path.as_list()) handle_field_error(error, item_type, errors) + completed_results.append(None) break + if complete_list_item_value( + value, + completed_results, + errors, + item_type, + field_nodes, + info, + item_path, + async_payload_record, + ): + append_awaitable(index) + index += 1 if not awaitable_indices: @@ -1307,12 +1283,11 @@ def complete_list_value( # This is specified as a simple map, however we're optimizing the path where # the list contains no coroutine objects by avoiding creating another coroutine # object. - is_awaitable = self.is_awaitable + complete_list_item_value = self.complete_list_item_value awaitable_indices: List[int] = [] append_awaitable = awaitable_indices.append previous_async_payload_record = async_payload_record completed_results: List[Any] = [] - append_result = completed_results.append for index, item in enumerate(result): # No need to modify the info object containing the path, since from here on # it is not ever accessed by resolver functions. @@ -1335,67 +1310,17 @@ def complete_list_value( ) continue - completed_item: AwaitableOrValue[Any] - - if is_awaitable(item): - # noinspection PyShadowingNames - async def await_completed(item: Any, item_path: Path) -> Any: - try: - completed = self.complete_value( - item_type, - field_nodes, - info, - item_path, - await item, - async_payload_record, - ) - if is_awaitable(completed): - return await completed - except Exception as raw_error: - error = located_error( - raw_error, field_nodes, item_path.as_list() - ) - handle_field_error(error, item_type, errors) - self.filter_subsequent_payloads(item_path, async_payload_record) - return None - return completed - - completed_item = await_completed(item, item_path) - else: - try: - completed_item = self.complete_value( - item_type, - field_nodes, - info, - item_path, - item, - async_payload_record, - ) - if is_awaitable(completed_item): - # noinspection PyShadowingNames - async def await_completed(item: Any, item_path: Path) -> Any: - try: - return await item - except Exception as raw_error: - error = located_error( - raw_error, field_nodes, item_path.as_list() - ) - handle_field_error(error, item_type, errors) - self.filter_subsequent_payloads( - item_path, async_payload_record - ) - return None - - completed_item = await_completed(completed_item, item_path) - except Exception as raw_error: - error = located_error(raw_error, field_nodes, item_path.as_list()) - handle_field_error(error, item_type, errors) - self.filter_subsequent_payloads(item_path, async_payload_record) - completed_item = None - - if is_awaitable(completed_item): + if complete_list_item_value( + item, + completed_results, + errors, + item_type, + field_nodes, + info, + item_path, + async_payload_record, + ): append_awaitable(index) - append_result(completed_item) if not awaitable_indices: return completed_results @@ -1418,6 +1343,74 @@ async def get_completed_results() -> List[Any]: return get_completed_results() + def complete_list_item_value( + self, + item: Any, + complete_results: List[Any], + errors: List[GraphQLError], + item_type: GraphQLOutputType, + field_nodes: List[FieldNode], + info: GraphQLResolveInfo, + item_path: Path, + async_payload_record: Optional[AsyncPayloadRecord], + ) -> bool: + """Complete a list item value by adding it to the completed results. + + Returns True if the value is awaitable. + """ + is_awaitable = self.is_awaitable + try: + if is_awaitable(item): + completed_item: Any + + async def await_completed() -> Any: + completed = self.complete_value( + item_type, + field_nodes, + info, + item_path, + await item, + async_payload_record, + ) + return await completed if is_awaitable(completed) else completed + + completed_item = await_completed() + else: + completed_item = self.complete_value( + item_type, + field_nodes, + info, + item_path, + item, + async_payload_record, + ) + + if is_awaitable(completed_item): + # noinspection PyShadowingNames + async def catch_error() -> Any: + try: + return await completed_item + except Exception as raw_error: + error = located_error( + raw_error, field_nodes, item_path.as_list() + ) + handle_field_error(error, item_type, errors) + self.filter_subsequent_payloads(item_path, async_payload_record) + return None + + complete_results.append(catch_error()) + return True + + complete_results.append(completed_item) + + except Exception as raw_error: + error = located_error(raw_error, field_nodes, item_path.as_list()) + handle_field_error(error, item_type, errors) + self.filter_subsequent_payloads(item_path, async_payload_record) + complete_results.append(None) + + return False + @staticmethod def complete_leaf_value(return_type: GraphQLLeafType, result: Any) -> Any: """Complete a leaf value. diff --git a/tests/execution/test_stream.py b/tests/execution/test_stream.py index 9ab60f9f..84719bb9 100644 --- a/tests/execution/test_stream.py +++ b/tests/execution/test_stream.py @@ -1230,6 +1230,7 @@ async def friend_list(_info): } @pytest.mark.asyncio() + @pytest.mark.filterwarnings("ignore:.* was never awaited:RuntimeWarning") async def does_not_filter_payloads_when_null_error_is_in_a_different_path(): document = parse( """