1+ import asyncio
12import copy
23from collections .abc import MutableMapping
34from functools import partial
67from flask import Response , render_template_string , request
78from flask .views import View
89from graphql .error import GraphQLError
10+ from graphql .pyutils import is_awaitable
911from graphql .type .schema import GraphQLSchema
1012
1113from graphql_server import (
@@ -41,6 +43,7 @@ class GraphQLView(View):
4143 default_query = None
4244 header_editor_enabled = None
4345 should_persist_headers = None
46+ enable_async = False
4447
4548 methods = ["GET" , "POST" , "PUT" , "DELETE" ]
4649
@@ -53,26 +56,51 @@ def __init__(self, **kwargs):
5356 if hasattr (self , key ):
5457 setattr (self , key , value )
5558
56- assert isinstance (
57- self .schema , GraphQLSchema
58- ), "A Schema is required to be provided to GraphQLView."
59+ assert isinstance (self .schema , GraphQLSchema ), "A Schema is required to be provided to GraphQLView."
5960
6061 def get_root_value (self ):
6162 return self .root_value
6263
6364 def get_context (self ):
64- context = (
65- copy .copy (self .context )
66- if self .context and isinstance (self .context , MutableMapping )
67- else {}
68- )
65+ context = copy .copy (self .context ) if self .context and isinstance (self .context , MutableMapping ) else {}
6966 if isinstance (context , MutableMapping ) and "request" not in context :
7067 context .update ({"request" : request })
7168 return context
7269
7370 def get_middleware (self ):
7471 return self .middleware
7572
73+ def result_results (self , request_method , data , catch ):
74+ return run_http_query (
75+ self .schema ,
76+ request_method ,
77+ data ,
78+ query_data = request .args ,
79+ batch_enabled = self .batch ,
80+ catch = catch ,
81+ # Execute options
82+ root_value = self .get_root_value (),
83+ context_value = self .get_context (),
84+ middleware = self .get_middleware (),
85+ run_sync = not self .enable_async ,
86+ )
87+
88+ async def resolve_results_async (self , request_method , data , catch ):
89+ execution_results , all_params = run_http_query (
90+ self .schema ,
91+ request_method ,
92+ data ,
93+ query_data = request .args ,
94+ batch_enabled = self .batch ,
95+ catch = catch ,
96+ # Execute options
97+ root_value = self .get_root_value (),
98+ context_value = self .get_context (),
99+ middleware = self .get_middleware (),
100+ run_sync = not self .enable_async ,
101+ )
102+ return [await ex if is_awaitable (ex ) else ex for ex in execution_results ], all_params
103+
76104 def dispatch_request (self ):
77105 try :
78106 request_method = request .method .lower ()
@@ -84,18 +112,11 @@ def dispatch_request(self):
84112 pretty = self .pretty or show_graphiql or request .args .get ("pretty" )
85113
86114 all_params : List [GraphQLParams ]
87- execution_results , all_params = run_http_query (
88- self .schema ,
89- request_method ,
90- data ,
91- query_data = request .args ,
92- batch_enabled = self .batch ,
93- catch = catch ,
94- # Execute options
95- root_value = self .get_root_value (),
96- context_value = self .get_context (),
97- middleware = self .get_middleware (),
98- )
115+ if self .enable_async :
116+ execution_results , all_params = asyncio .run (self .resolve_results_async (request_method , data , catch ))
117+ else :
118+ execution_results , all_params = self .result_results (request_method , data , catch )
119+
99120 result , status_code = encode_execution_results (
100121 execution_results ,
101122 is_batch = isinstance (data , list ),
@@ -123,9 +144,7 @@ def dispatch_request(self):
123144 header_editor_enabled = self .header_editor_enabled ,
124145 should_persist_headers = self .should_persist_headers ,
125146 )
126- source = render_graphiql_sync (
127- data = graphiql_data , config = graphiql_config , options = graphiql_options
128- )
147+ source = render_graphiql_sync (data = graphiql_data , config = graphiql_config , options = graphiql_options )
129148 return render_template_string (source )
130149
131150 return Response (result , status = status_code , content_type = "application/json" )
@@ -150,10 +169,7 @@ def parse_body(self):
150169 elif content_type == "application/json" :
151170 return load_json_body (request .data .decode ("utf8" ))
152171
153- elif content_type in (
154- "application/x-www-form-urlencoded" ,
155- "multipart/form-data" ,
156- ):
172+ elif content_type in ("application/x-www-form-urlencoded" , "multipart/form-data" ,):
157173 return request .form
158174
159175 return {}
@@ -166,8 +182,4 @@ def should_display_graphiql(self):
166182
167183 def request_wants_html (self ):
168184 best = request .accept_mimetypes .best_match (["application/json" , "text/html" ])
169- return (
170- best == "text/html"
171- and request .accept_mimetypes [best ]
172- > request .accept_mimetypes ["application/json" ]
173- )
185+ return best == "text/html" and request .accept_mimetypes [best ] > request .accept_mimetypes ["application/json" ]
0 commit comments