Skip to content

Commit bdb41ea

Browse files
author
Cameron Hurst
committed
fix: enable async cleaned up with hook points
1 parent c03e1a4 commit bdb41ea

File tree

2 files changed

+29
-40
lines changed

2 files changed

+29
-40
lines changed

graphql_server/__init__.py

+9-24
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from graphql.error import format_error as format_error_default
1616
from graphql.execution import ExecutionResult, execute
1717
from graphql.language import OperationType, parse
18-
from graphql.pyutils import AwaitableOrValue
18+
from graphql.pyutils import AwaitableOrValue, is_awaitable
1919
from graphql.type import GraphQLSchema, validate_schema
2020
from graphql.utilities import get_operation_ast
2121
from graphql.validation import ASTValidationRule, validate
@@ -99,9 +99,7 @@ def run_http_query(
9999

100100
if not is_batch:
101101
if not isinstance(data, (dict, MutableMapping)):
102-
raise HttpQueryError(
103-
400, f"GraphQL params should be a dict. Received {data!r}."
104-
)
102+
raise HttpQueryError(400, f"GraphQL params should be a dict. Received {data!r}.")
105103
data = [data]
106104
elif not batch_enabled:
107105
raise HttpQueryError(400, "Batch GraphQL requests are not enabled.")
@@ -114,15 +112,10 @@ def run_http_query(
114112
if not is_batch:
115113
extra_data = query_data or {}
116114

117-
all_params: List[GraphQLParams] = [
118-
get_graphql_params(entry, extra_data) for entry in data
119-
]
115+
all_params: List[GraphQLParams] = [get_graphql_params(entry, extra_data) for entry in data]
120116

121117
results: List[Optional[AwaitableOrValue[ExecutionResult]]] = [
122-
get_response(
123-
schema, params, catch_exc, allow_only_query, run_sync, **execute_options
124-
)
125-
for params in all_params
118+
get_response(schema, params, catch_exc, allow_only_query, run_sync, **execute_options) for params in all_params
126119
]
127120
return GraphQLResponse(results, all_params)
128121

@@ -160,10 +153,7 @@ def encode_execution_results(
160153
Returns a ServerResponse tuple with the serialized response as the first item and
161154
a status code of 200 or 400 in case any result was invalid as the second item.
162155
"""
163-
results = [
164-
format_execution_result(execution_result, format_error)
165-
for execution_result in execution_results
166-
]
156+
results = [format_execution_result(execution_result, format_error) for execution_result in execution_results]
167157
result, status_codes = zip(*results)
168158
status_code = max(status_codes)
169159

@@ -274,14 +264,11 @@ def get_response(
274264
if operation != OperationType.QUERY.value:
275265
raise HttpQueryError(
276266
405,
277-
f"Can only perform a {operation} operation"
278-
" from a POST request.",
267+
f"Can only perform a {operation} operation" " from a POST request.",
279268
headers={"Allow": "POST"},
280269
)
281270

282-
validation_errors = validate(
283-
schema, document, rules=validation_rules, max_errors=max_errors
284-
)
271+
validation_errors = validate(schema, document, rules=validation_rules, max_errors=max_errors)
285272
if validation_errors:
286273
return ExecutionResult(data=None, errors=validation_errors)
287274

@@ -290,7 +277,7 @@ def get_response(
290277
document,
291278
variable_values=params.variables,
292279
operation_name=params.operation_name,
293-
is_awaitable=assume_not_awaitable if run_sync else None,
280+
is_awaitable=assume_not_awaitable if run_sync else is_awaitable,
294281
**kwargs,
295282
)
296283

@@ -317,9 +304,7 @@ def format_execution_result(
317304
fe = [format_error(e) for e in execution_result.errors] # type: ignore
318305
response = {"errors": fe}
319306

320-
if execution_result.errors and any(
321-
not getattr(e, "path", None) for e in execution_result.errors
322-
):
307+
if execution_result.errors and any(not getattr(e, "path", None) for e in execution_result.errors):
323308
status_code = 400
324309
else:
325310
response["data"] = execution_result.data

graphql_server/flask/graphqlview.py

+20-16
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
1+
import asyncio
12
import copy
23
from collections.abc import MutableMapping
34
from functools import partial
45
from typing import List
56

67
from flask import Response, render_template_string, request
78
from flask.views import View
9+
from graphql import ExecutionResult
810
from graphql.error import GraphQLError
11+
from graphql.pyutils import is_awaitable
912
from graphql.type.schema import GraphQLSchema
1013

1114
from graphql_server import (
@@ -41,6 +44,7 @@ class GraphQLView(View):
4144
default_query = None
4245
header_editor_enabled = None
4346
should_persist_headers = None
47+
enable_async = True
4448

4549
methods = ["GET", "POST", "PUT", "DELETE"]
4650

@@ -53,26 +57,27 @@ def __init__(self, **kwargs):
5357
if hasattr(self, key):
5458
setattr(self, key, value)
5559

56-
assert isinstance(
57-
self.schema, GraphQLSchema
58-
), "A Schema is required to be provided to GraphQLView."
60+
assert isinstance(self.schema, GraphQLSchema), "A Schema is required to be provided to GraphQLView."
5961

6062
def get_root_value(self):
6163
return self.root_value
6264

6365
def get_context(self):
64-
context = (
65-
copy.copy(self.context)
66-
if self.context and isinstance(self.context, MutableMapping)
67-
else {}
68-
)
66+
context = copy.copy(self.context) if self.context and isinstance(self.context, MutableMapping) else {}
6967
if isinstance(context, MutableMapping) and "request" not in context:
7068
context.update({"request": request})
7169
return context
7270

7371
def get_middleware(self):
7472
return self.middleware
7573

74+
@staticmethod
75+
def get_async_execution_results(execution_results):
76+
async def await_execution_results(execution_results):
77+
return [ex if ex is None or is_awaitable(ex) else await ex for ex in execution_results]
78+
79+
return asyncio.run(await_execution_results(execution_results))
80+
7681
def dispatch_request(self):
7782
try:
7883
request_method = request.method.lower()
@@ -96,6 +101,11 @@ def dispatch_request(self):
96101
context_value=self.get_context(),
97102
middleware=self.get_middleware(),
98103
)
104+
105+
if self.enable_async:
106+
if any(is_awaitable(ex) for ex in execution_results):
107+
execution_results = self.get_async_execution_results(execution_results)
108+
99109
result, status_code = encode_execution_results(
100110
execution_results,
101111
is_batch=isinstance(data, list),
@@ -123,9 +133,7 @@ def dispatch_request(self):
123133
header_editor_enabled=self.header_editor_enabled,
124134
should_persist_headers=self.should_persist_headers,
125135
)
126-
source = render_graphiql_sync(
127-
data=graphiql_data, config=graphiql_config, options=graphiql_options
128-
)
136+
source = render_graphiql_sync(data=graphiql_data, config=graphiql_config, options=graphiql_options)
129137
return render_template_string(source)
130138

131139
return Response(result, status=status_code, content_type="application/json")
@@ -167,8 +175,4 @@ def should_display_graphiql(self):
167175
@staticmethod
168176
def request_wants_html():
169177
best = request.accept_mimetypes.best_match(["application/json", "text/html"])
170-
return (
171-
best == "text/html"
172-
and request.accept_mimetypes[best]
173-
> request.accept_mimetypes["application/json"]
174-
)
178+
return best == "text/html" and request.accept_mimetypes[best] > request.accept_mimetypes["application/json"]

0 commit comments

Comments
 (0)