2
2
import functools
3
3
import logging
4
4
import sys
5
+ from rx import Observable
5
6
6
7
from six import string_types
7
8
from promise import Promise , promise_for_dict , is_thenable
15
16
GraphQLSchema , GraphQLUnionType )
16
17
from .base import (ExecutionContext , ExecutionResult , ResolveInfo ,
17
18
collect_fields , default_resolve_fn , get_field_def ,
18
- get_operation_root_type )
19
+ get_operation_root_type , SubscriberExecutionContext )
19
20
from .executors .sync import SyncExecutor
20
21
from .middleware import MiddlewareManager
21
22
22
23
logger = logging .getLogger (__name__ )
23
24
24
25
26
+ def subscribe (* args , ** kwargs ):
27
+ allow_subscriptions = kwargs .pop ('allow_subscriptions' , True )
28
+ return execute (* args , allow_subscriptions = allow_subscriptions , ** kwargs )
29
+
30
+
25
31
def execute (schema , document_ast , root_value = None , context_value = None ,
26
32
variable_values = None , operation_name = None , executor = None ,
27
- return_promise = False , middleware = None ):
33
+ return_promise = False , middleware = None , allow_subscriptions = False ):
28
34
assert schema , 'Must provide schema'
29
35
assert isinstance (schema , GraphQLSchema ), (
30
36
'Schema must be an instance of GraphQLSchema. Also ensure that there are ' +
@@ -50,7 +56,8 @@ def execute(schema, document_ast, root_value=None, context_value=None,
50
56
variable_values ,
51
57
operation_name ,
52
58
executor ,
53
- middleware
59
+ middleware ,
60
+ allow_subscriptions
54
61
)
55
62
56
63
def executor (v ):
@@ -61,6 +68,9 @@ def on_rejected(error):
61
68
return None
62
69
63
70
def on_resolve (data ):
71
+ if isinstance (data , Observable ):
72
+ return data
73
+
64
74
if not context .errors :
65
75
return ExecutionResult (data = data )
66
76
return ExecutionResult (data = data , errors = context .errors )
@@ -88,6 +98,15 @@ def execute_operation(exe_context, operation, root_value):
88
98
if operation .operation == 'mutation' :
89
99
return execute_fields_serially (exe_context , type , root_value , fields )
90
100
101
+ if operation .operation == 'subscription' :
102
+ if not exe_context .allow_subscriptions :
103
+ raise Exception (
104
+ "Subscriptions are not allowed. "
105
+ "You will need to either use the subscribe function "
106
+ "or pass allow_subscriptions=True"
107
+ )
108
+ return subscribe_fields (exe_context , type , root_value , fields )
109
+
91
110
return execute_fields (exe_context , type , root_value , fields )
92
111
93
112
@@ -140,6 +159,44 @@ def execute_fields(exe_context, parent_type, source_value, fields):
140
159
return promise_for_dict (final_results )
141
160
142
161
162
+ def subscribe_fields (exe_context , parent_type , source_value , fields ):
163
+ exe_context = SubscriberExecutionContext (exe_context )
164
+
165
+ def on_error (error ):
166
+ exe_context .report_error (error )
167
+
168
+ def map_result (data ):
169
+ if exe_context .errors :
170
+ result = ExecutionResult (data = data , errors = exe_context .errors )
171
+ else :
172
+ result = ExecutionResult (data = data )
173
+ exe_context .reset ()
174
+ return result
175
+
176
+ observables = []
177
+
178
+ # assert len(fields) == 1, "Can only subscribe one element at a time."
179
+
180
+ for response_name , field_asts in fields .items ():
181
+
182
+ result = subscribe_field (exe_context , parent_type ,
183
+ source_value , field_asts )
184
+ if result is Undefined :
185
+ continue
186
+
187
+ def catch_error (error ):
188
+ exe_context .errors .append (error )
189
+ return Observable .just (None )
190
+
191
+ # Map observable results
192
+ observable = result .catch_exception (catch_error ).map (
193
+ lambda data : map_result ({response_name : data }))
194
+ return observable
195
+ observables .append (observable )
196
+
197
+ return Observable .merge (observables )
198
+
199
+
143
200
def resolve_field (exe_context , parent_type , source , field_asts ):
144
201
field_ast = field_asts [0 ]
145
202
field_name = field_ast .name .value
@@ -191,6 +248,64 @@ def resolve_field(exe_context, parent_type, source, field_asts):
191
248
)
192
249
193
250
251
+ def subscribe_field (exe_context , parent_type , source , field_asts ):
252
+ field_ast = field_asts [0 ]
253
+ field_name = field_ast .name .value
254
+
255
+ field_def = get_field_def (exe_context .schema , parent_type , field_name )
256
+ if not field_def :
257
+ return Undefined
258
+
259
+ return_type = field_def .type
260
+ resolve_fn = field_def .resolver or default_resolve_fn
261
+
262
+ # We wrap the resolve_fn from the middleware
263
+ resolve_fn_middleware = exe_context .get_field_resolver (resolve_fn )
264
+
265
+ # Build a dict of arguments from the field.arguments AST, using the variables scope to
266
+ # fulfill any variable references.
267
+ args = exe_context .get_argument_values (field_def , field_ast )
268
+
269
+ # The resolve function's optional third argument is a context value that
270
+ # is provided to every resolve function within an execution. It is commonly
271
+ # used to represent an authenticated user, or request-specific caches.
272
+ context = exe_context .context_value
273
+
274
+ # The resolve function's optional third argument is a collection of
275
+ # information about the current execution state.
276
+ info = ResolveInfo (
277
+ field_name ,
278
+ field_asts ,
279
+ return_type ,
280
+ parent_type ,
281
+ schema = exe_context .schema ,
282
+ fragments = exe_context .fragments ,
283
+ root_value = exe_context .root_value ,
284
+ operation = exe_context .operation ,
285
+ variable_values = exe_context .variable_values ,
286
+ context = context
287
+ )
288
+
289
+ executor = exe_context .executor
290
+ result = resolve_or_error (resolve_fn_middleware ,
291
+ source , info , args , executor )
292
+
293
+ if isinstance (result , Exception ):
294
+ raise result
295
+
296
+ if not isinstance (result , Observable ):
297
+ raise GraphQLError (
298
+ 'Subscription must return Async Iterable or Observable. Received: {}' .format (repr (result )))
299
+
300
+ return result .map (functools .partial (
301
+ complete_value_catching_error ,
302
+ exe_context ,
303
+ return_type ,
304
+ field_asts ,
305
+ info ,
306
+ ))
307
+
308
+
194
309
def resolve_or_error (resolve_fn , source , info , args , executor ):
195
310
try :
196
311
return executor .execute (resolve_fn , source , info , ** args )
0 commit comments