Skip to content

Commit 0b0551e

Browse files
authored
Merge pull request #143 from graphql-python/features/observable-subscriptions
[WIP] Subscriptions
2 parents 305eb80 + d85382d commit 0b0551e

File tree

14 files changed

+782
-59
lines changed

14 files changed

+782
-59
lines changed

.travis.yml

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,7 @@ language: python
22
sudo: false
33
python:
44
- 2.7
5-
- 3.4
6-
- 3.5
7-
- 3.6
8-
- "pypy-5.3.1"
5+
# - "pypy-5.3.1"
96
before_install:
107
- |
118
if [ "$TRAVIS_PYTHON_VERSION" = "pypy" ]; then
@@ -22,7 +19,9 @@ before_install:
2219
fi
2320
install:
2421
- pip install -e .[test]
22+
- pip install flake8
2523
script:
24+
- flake8
2625
- py.test --cov=graphql graphql tests
2726
after_success:
2827
- coveralls
@@ -33,10 +32,13 @@ matrix:
3332
- pip install pytest-asyncio
3433
script:
3534
- py.test --cov=graphql graphql tests tests_py35
36-
- python: '2.7'
37-
install: pip install flake8
35+
- python: '3.6'
36+
after_install:
37+
- pip install pytest-asyncio
3838
script:
39-
- flake8
39+
- py.test --cov=graphql graphql tests tests_py35
40+
- python: '2.7'
41+
4042
deploy:
4143
provider: pypi
4244
user: syrusakbary

graphql/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,7 @@
120120
# Execute GraphQL queries.
121121
from .execution import ( # no import order
122122
execute,
123+
subscribe,
123124
ResolveInfo,
124125
MiddlewareManager,
125126
middlewares
@@ -254,6 +255,7 @@
254255
'print_ast',
255256
'visit',
256257
'execute',
258+
'subscribe',
257259
'ResolveInfo',
258260
'MiddlewareManager',
259261
'middlewares',

graphql/execution/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,14 @@
1818
2) fragment "spreads" e.g. "...c"
1919
3) inline fragment "spreads" e.g. "...on Type { a }"
2020
"""
21-
from .executor import execute
21+
from .executor import execute, subscribe
2222
from .base import ExecutionResult, ResolveInfo
2323
from .middleware import middlewares, MiddlewareManager
2424

2525

2626
__all__ = [
2727
'execute',
28+
'subscribe',
2829
'ExecutionResult',
2930
'ResolveInfo',
3031
'MiddlewareManager',

graphql/execution/base.py

Lines changed: 31 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,9 @@ class ExecutionContext(object):
1919
and the fragments defined in the query document"""
2020

2121
__slots__ = 'schema', 'fragments', 'root_value', 'operation', 'variable_values', 'errors', 'context_value', \
22-
'argument_values_cache', 'executor', 'middleware', '_subfields_cache'
22+
'argument_values_cache', 'executor', 'middleware', 'allow_subscriptions', '_subfields_cache'
2323

24-
def __init__(self, schema, document_ast, root_value, context_value, variable_values, operation_name, executor, middleware):
24+
def __init__(self, schema, document_ast, root_value, context_value, variable_values, operation_name, executor, middleware, allow_subscriptions):
2525
"""Constructs a ExecutionContext object from the arguments passed
2626
to execute, which we will pass throughout the other execution
2727
methods."""
@@ -32,7 +32,8 @@ def __init__(self, schema, document_ast, root_value, context_value, variable_val
3232
for definition in document_ast.definitions:
3333
if isinstance(definition, ast.OperationDefinition):
3434
if not operation_name and operation:
35-
raise GraphQLError('Must provide operation name if query contains multiple operations.')
35+
raise GraphQLError(
36+
'Must provide operation name if query contains multiple operations.')
3637

3738
if not operation_name or definition.name and definition.name.value == operation_name:
3839
operation = definition
@@ -42,18 +43,21 @@ def __init__(self, schema, document_ast, root_value, context_value, variable_val
4243

4344
else:
4445
raise GraphQLError(
45-
u'GraphQL cannot execute a request containing a {}.'.format(definition.__class__.__name__),
46+
u'GraphQL cannot execute a request containing a {}.'.format(
47+
definition.__class__.__name__),
4648
definition
4749
)
4850

4951
if not operation:
5052
if operation_name:
51-
raise GraphQLError(u'Unknown operation named "{}".'.format(operation_name))
53+
raise GraphQLError(
54+
u'Unknown operation named "{}".'.format(operation_name))
5255

5356
else:
5457
raise GraphQLError('Must provide an operation.')
5558

56-
variable_values = get_variable_values(schema, operation.variable_definitions or [], variable_values)
59+
variable_values = get_variable_values(
60+
schema, operation.variable_definitions or [], variable_values)
5761

5862
self.schema = schema
5963
self.fragments = fragments
@@ -65,6 +69,7 @@ def __init__(self, schema, document_ast, root_value, context_value, variable_val
6569
self.argument_values_cache = {}
6670
self.executor = executor
6771
self.middleware = middleware
72+
self.allow_subscriptions = allow_subscriptions
6873
self._subfields_cache = {}
6974

7075
def get_field_resolver(self, field_resolver):
@@ -82,7 +87,8 @@ def get_argument_values(self, field_def, field_ast):
8287
return result
8388

8489
def report_error(self, error, traceback=None):
85-
sys.excepthook(type(error), error, getattr(error, 'stack', None) or traceback)
90+
sys.excepthook(type(error), error, getattr(
91+
error, 'stack', None) or traceback)
8692
self.errors.append(error)
8793

8894
def get_sub_fields(self, return_type, field_asts):
@@ -101,6 +107,20 @@ def get_sub_fields(self, return_type, field_asts):
101107
return self._subfields_cache[k]
102108

103109

110+
class SubscriberExecutionContext(object):
111+
__slots__ = 'exe_context', 'errors'
112+
113+
def __init__(self, exe_context):
114+
self.exe_context = exe_context
115+
self.errors = []
116+
117+
def reset(self):
118+
self.errors = []
119+
120+
def __getattr__(self, name):
121+
return getattr(self.exe_context, name)
122+
123+
104124
class ExecutionResult(object):
105125
"""The result of execution. `data` is the result of executing the
106126
query, `errors` is null if no errors occurred, and is a
@@ -186,7 +206,8 @@ def collect_fields(ctx, runtime_type, selection_set, fields, prev_fragment_names
186206
ctx, selection, runtime_type):
187207
continue
188208

189-
collect_fields(ctx, runtime_type, selection.selection_set, fields, prev_fragment_names)
209+
collect_fields(ctx, runtime_type,
210+
selection.selection_set, fields, prev_fragment_names)
190211

191212
elif isinstance(selection, ast.FragmentSpread):
192213
frag_name = selection.name.value
@@ -202,7 +223,8 @@ def collect_fields(ctx, runtime_type, selection_set, fields, prev_fragment_names
202223
does_fragment_condition_match(ctx, fragment, runtime_type):
203224
continue
204225

205-
collect_fields(ctx, runtime_type, fragment.selection_set, fields, prev_fragment_names)
226+
collect_fields(ctx, runtime_type,
227+
fragment.selection_set, fields, prev_fragment_names)
206228

207229
return fields
208230

graphql/execution/executor.py

Lines changed: 118 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import functools
33
import logging
44
import sys
5+
from rx import Observable
56

67
from six import string_types
78
from promise import Promise, promise_for_dict, is_thenable
@@ -15,16 +16,21 @@
1516
GraphQLSchema, GraphQLUnionType)
1617
from .base import (ExecutionContext, ExecutionResult, ResolveInfo,
1718
collect_fields, default_resolve_fn, get_field_def,
18-
get_operation_root_type)
19+
get_operation_root_type, SubscriberExecutionContext)
1920
from .executors.sync import SyncExecutor
2021
from .middleware import MiddlewareManager
2122

2223
logger = logging.getLogger(__name__)
2324

2425

26+
def subscribe(*args, **kwargs):
27+
allow_subscriptions = kwargs.pop('allow_subscriptions', True)
28+
return execute(*args, allow_subscriptions=allow_subscriptions, **kwargs)
29+
30+
2531
def execute(schema, document_ast, root_value=None, context_value=None,
2632
variable_values=None, operation_name=None, executor=None,
27-
return_promise=False, middleware=None):
33+
return_promise=False, middleware=None, allow_subscriptions=False):
2834
assert schema, 'Must provide schema'
2935
assert isinstance(schema, GraphQLSchema), (
3036
'Schema must be an instance of GraphQLSchema. Also ensure that there are ' +
@@ -50,7 +56,8 @@ def execute(schema, document_ast, root_value=None, context_value=None,
5056
variable_values,
5157
operation_name,
5258
executor,
53-
middleware
59+
middleware,
60+
allow_subscriptions
5461
)
5562

5663
def executor(v):
@@ -61,6 +68,9 @@ def on_rejected(error):
6168
return None
6269

6370
def on_resolve(data):
71+
if isinstance(data, Observable):
72+
return data
73+
6474
if not context.errors:
6575
return ExecutionResult(data=data)
6676
return ExecutionResult(data=data, errors=context.errors)
@@ -88,6 +98,15 @@ def execute_operation(exe_context, operation, root_value):
8898
if operation.operation == 'mutation':
8999
return execute_fields_serially(exe_context, type, root_value, fields)
90100

101+
if operation.operation == 'subscription':
102+
if not exe_context.allow_subscriptions:
103+
raise Exception(
104+
"Subscriptions are not allowed. "
105+
"You will need to either use the subscribe function "
106+
"or pass allow_subscriptions=True"
107+
)
108+
return subscribe_fields(exe_context, type, root_value, fields)
109+
91110
return execute_fields(exe_context, type, root_value, fields)
92111

93112

@@ -140,6 +159,44 @@ def execute_fields(exe_context, parent_type, source_value, fields):
140159
return promise_for_dict(final_results)
141160

142161

162+
def subscribe_fields(exe_context, parent_type, source_value, fields):
163+
exe_context = SubscriberExecutionContext(exe_context)
164+
165+
def on_error(error):
166+
exe_context.report_error(error)
167+
168+
def map_result(data):
169+
if exe_context.errors:
170+
result = ExecutionResult(data=data, errors=exe_context.errors)
171+
else:
172+
result = ExecutionResult(data=data)
173+
exe_context.reset()
174+
return result
175+
176+
observables = []
177+
178+
# assert len(fields) == 1, "Can only subscribe one element at a time."
179+
180+
for response_name, field_asts in fields.items():
181+
182+
result = subscribe_field(exe_context, parent_type,
183+
source_value, field_asts)
184+
if result is Undefined:
185+
continue
186+
187+
def catch_error(error):
188+
exe_context.errors.append(error)
189+
return Observable.just(None)
190+
191+
# Map observable results
192+
observable = result.catch_exception(catch_error).map(
193+
lambda data: map_result({response_name: data}))
194+
return observable
195+
observables.append(observable)
196+
197+
return Observable.merge(observables)
198+
199+
143200
def resolve_field(exe_context, parent_type, source, field_asts):
144201
field_ast = field_asts[0]
145202
field_name = field_ast.name.value
@@ -191,6 +248,64 @@ def resolve_field(exe_context, parent_type, source, field_asts):
191248
)
192249

193250

251+
def subscribe_field(exe_context, parent_type, source, field_asts):
252+
field_ast = field_asts[0]
253+
field_name = field_ast.name.value
254+
255+
field_def = get_field_def(exe_context.schema, parent_type, field_name)
256+
if not field_def:
257+
return Undefined
258+
259+
return_type = field_def.type
260+
resolve_fn = field_def.resolver or default_resolve_fn
261+
262+
# We wrap the resolve_fn from the middleware
263+
resolve_fn_middleware = exe_context.get_field_resolver(resolve_fn)
264+
265+
# Build a dict of arguments from the field.arguments AST, using the variables scope to
266+
# fulfill any variable references.
267+
args = exe_context.get_argument_values(field_def, field_ast)
268+
269+
# The resolve function's optional third argument is a context value that
270+
# is provided to every resolve function within an execution. It is commonly
271+
# used to represent an authenticated user, or request-specific caches.
272+
context = exe_context.context_value
273+
274+
# The resolve function's optional third argument is a collection of
275+
# information about the current execution state.
276+
info = ResolveInfo(
277+
field_name,
278+
field_asts,
279+
return_type,
280+
parent_type,
281+
schema=exe_context.schema,
282+
fragments=exe_context.fragments,
283+
root_value=exe_context.root_value,
284+
operation=exe_context.operation,
285+
variable_values=exe_context.variable_values,
286+
context=context
287+
)
288+
289+
executor = exe_context.executor
290+
result = resolve_or_error(resolve_fn_middleware,
291+
source, info, args, executor)
292+
293+
if isinstance(result, Exception):
294+
raise result
295+
296+
if not isinstance(result, Observable):
297+
raise GraphQLError(
298+
'Subscription must return Async Iterable or Observable. Received: {}'.format(repr(result)))
299+
300+
return result.map(functools.partial(
301+
complete_value_catching_error,
302+
exe_context,
303+
return_type,
304+
field_asts,
305+
info,
306+
))
307+
308+
194309
def resolve_or_error(resolve_fn, source, info, args, executor):
195310
try:
196311
return executor.execute(resolve_fn, source, info, **args)

graphql/execution/executors/asyncio.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,15 @@ def ensure_future(coro_or_future, loop=None):
2525
del task._source_traceback[-1]
2626
return task
2727
else:
28-
raise TypeError('A Future, a coroutine or an awaitable is required')
28+
raise TypeError(
29+
'A Future, a coroutine or an awaitable is required')
30+
31+
try:
32+
from .asyncio_utils import asyncgen_to_observable, isasyncgen
33+
except Exception:
34+
def isasyncgen(obj): False
35+
36+
def asyncgen_to_observable(asyncgen): pass
2937

3038

3139
class AsyncioExecutor(object):
@@ -50,4 +58,6 @@ def execute(self, fn, *args, **kwargs):
5058
future = ensure_future(result, loop=self.loop)
5159
self.futures.append(future)
5260
return Promise.resolve(future)
61+
elif isasyncgen(result):
62+
return asyncgen_to_observable(result)
5363
return result

0 commit comments

Comments
 (0)