diff --git a/README.md b/README.md index 338bb9b..43d108c 100644 --- a/README.md +++ b/README.md @@ -9,13 +9,13 @@ Currently supports: # Installation instructions -For instaling graphql-ws, just run this command in your shell +For installing graphql-ws, just run this command in your shell ```bash pip install graphql-ws ``` -## Examples +## Subscription Server ### aiohttp @@ -63,90 +63,167 @@ async def subscriptions(request, ws): app.run(host="0.0.0.0", port=8000) ``` - -And then, plug into a subscribable schema: +### Gevent +For setting up, just plug into your Gevent server. ```python -import asyncio -import graphene +subscription_server = GeventSubscriptionServer(schema) +app.app_protocol = lambda environ_path_info: 'graphql-ws' +@sockets.route('/subscriptions') +def echo_socket(ws): + subscription_server.handle(ws) + return [] +``` +### Django (with channels) -class Query(graphene.ObjectType): - base = graphene.String() +First `pip install channels` and add it to your django apps +Then add the following to your settings.py -class Subscription(graphene.ObjectType): - count_seconds = graphene.Float(up_to=graphene.Int()) +```python + CHANNELS_WS_PROTOCOLS = ["graphql-ws", ] + CHANNEL_LAYERS = { + "default": { + "BACKEND": "asgiref.inmemory.ChannelLayer", + "ROUTING": "django_subscriptions.urls.channel_routing", + }, - async def resolve_count_seconds(root, info, up_to): - for i in range(up_to): - yield i - await asyncio.sleep(1.) - yield up_to + } +``` + +Add the channel routes to your Django server. +```python +from channels.routing import route_class +from graphql_ws.django_channels import GraphQLSubscriptionConsumer -schema = graphene.Schema(query=Query, subscription=Subscription) +channel_routing = [ + route_class(GraphQLSubscriptionConsumer, path=r"^/subscriptions"), +] ``` -You can see a full example here: https://github.com/graphql-python/graphql-ws/tree/master/examples/aiohttp +## Publish-Subscribe +Included are several publish-subscribe (pubsub) classes for hooking +up your mutations to your subscriptions. When a client makes a +subscription, the pubsub can be used to map from one subscription name +to one or more channel names to subscribe to the right channels. +The subscription query will be re-run every time something is +published to one of these channels. Using these classes, a +subscription is just the result of a mutation. -### Gevent +### Asyncio -For setting up, just plug into your Gevent server. +There are two pubsub classes for asyncio, one that is in-memory and the other +that utilizes Redis (for production), via the [aredis](https://github.com/NoneGG/aredis) libary, which +is a asynchronous port of the excellent [redis-py](https://github.com/andymccurdy/redis-py) library. + +The schema for asyncio would look something like this below: ```python -subscription_server = GeventSubscriptionServer(schema) -app.app_protocol = lambda environ_path_info: 'graphql-ws' +import asyncio +import graphene -@sockets.route('/subscriptions') -def echo_socket(ws): - subscription_server.handle(ws) - return [] +from graphql_ws.pubsub import AsyncioPubsub + +# create a new pubsub object; this class is in-memory and does +# not utilze Redis +pubsub = AsyncioPubsub() + + +class MutationExample(graphene.Mutation): + class Arguments: + input_text = graphene.String() + + output_text = graphene.String() + + async def mutate(self, info, input_text): + # publish to the pubsub object before returning mutation + await pubsub.publish('BASE', input_text) + return MutationExample(output_text=input_text) + + +class Mutations(graphene.ObjectType): + mutation_example = MutationExample.Field() + + +class Subscription(graphene.ObjectType): + mutation_example = graphene.String() + + async def resolve_mutation_example(root, info): + try: + # pubsub subscribe_to_channel method returns + # subscription id and an asyncio.Queue + sub_id, q = pubsub.subscribe_to_channel('BASE') + while True: + payload = await q.get() + yield payload + except asyncio.CancelledError: + # unsubscribe subscription id from channel + # when coroutine is cancelled + pubsub.unsubscribe('BASE', sub_id) + +schema = graphene.Schema(mutation=Mutations, + subscription=Subscription) ``` -And then, plug into a subscribable schema: +You can see a full asyncio example here: https://github.com/graphql-python/graphql-ws/tree/master/examples/aiohttp + +### Gevent + +There are two pubsub classes for Gevent as well, one that is +in-memory and the other that utilizes Redis (for production), via +[redis-py](https://github.com/andymccurdy/redis-py). + +Finally, plug into a subscribable schema: ```python import graphene + +from graphql_ws.pubsub import GeventRxRedisPubsub from rx import Observable +# create a new pubsub object; in the case you'll need to +# be running a redis-server instance in a separate process +pubsub = GeventRxRedisPubsub() -class Query(graphene.ObjectType): - base = graphene.String() +class MutationExample(graphene.Mutation): + class Arguments: + input_text = graphene.String() -class Subscription(graphene.ObjectType): - count_seconds = graphene.Float(up_to=graphene.Int()) + output_text = graphene.String() - async def resolve_count_seconds(root, info, up_to=5): - return Observable.interval(1000)\ - .map(lambda i: "{0}".format(i))\ - .take_while(lambda i: int(i) <= up_to) + def mutate(self, info, input_text): + # publish to the pubsub before returning mutation + pubsub.publish('BASE', input_text) + return MutationExample(output_text=input_text) -schema = graphene.Schema(query=Query, subscription=Subscription) -``` +class Mutations(graphene.ObjectType): + mutation_example = MutationExample.Field() -You can see a full example here: https://github.com/graphql-python/graphql-ws/tree/master/examples/flask_gevent +class Subscription(graphene.ObjectType): + mutation_example = graphene.String() -### Django Channels + def resolve_mutation_example(root, info): + # pubsub subscribe_to_channel method returns an observable + # when observable is disposed of, the subscription will + # be cleaned up and unsubscribed from + return pubsub.subscribe_to_channel('BASE')\ + .map(lambda i: "{0}".format(i)) -First `pip install channels` and it to your django apps +schema = graphene.Schema(mutation=Mutations, + subscription=Subscription) +``` -Then add the following to your settings.py +You can see a full example here: https://github.com/graphql-python/graphql-ws/tree/master/examples/flask_gevent -```python - CHANNELS_WS_PROTOCOLS = ["graphql-ws", ] - CHANNEL_LAYERS = { - "default": { - "BACKEND": "asgiref.inmemory.ChannelLayer", - "ROUTING": "django_subscriptions.urls.channel_routing", - }, - } -``` +### Django (with channels) + Setup your graphql schema @@ -167,8 +244,8 @@ class Subscription(graphene.ObjectType): def resolve_count_seconds( - root, - info, + root, + info, up_to=5 ): return Observable.interval(1000)\ @@ -192,14 +269,3 @@ GRAPHENE = { 'SCHEMA': 'path.to.schema' } ``` - -and finally add the channel routes - -```python -from channels.routing import route_class -from graphql_ws.django_channels import GraphQLSubscriptionConsumer - -channel_routing = [ - route_class(GraphQLSubscriptionConsumer, path=r"^/subscriptions"), -] -``` \ No newline at end of file diff --git a/examples/aiohttp/schema.py b/examples/aiohttp/schema.py index 3c23d00..351d0bd 100644 --- a/examples/aiohttp/schema.py +++ b/examples/aiohttp/schema.py @@ -2,10 +2,32 @@ import asyncio import graphene +from graphql_ws.pubsub import AsyncioPubsub + +pubsub = AsyncioPubsub() + class Query(graphene.ObjectType): base = graphene.String() + async def resolve_base(root, info): + return 'Hello World!' + + +class MutationExample(graphene.Mutation): + class Arguments: + input_text = graphene.String() + + output_text = graphene.String() + + async def mutate(self, info, input_text): + await pubsub.publish('BASE', input_text) + return MutationExample(output_text=input_text) + + +class Mutations(graphene.ObjectType): + mutation_example = MutationExample.Field() + class RandomType(graphene.ObjectType): seconds = graphene.Int() @@ -15,6 +37,16 @@ class RandomType(graphene.ObjectType): class Subscription(graphene.ObjectType): count_seconds = graphene.Float(up_to=graphene.Int()) random_int = graphene.Field(RandomType) + mutation_example = graphene.String() + + async def resolve_mutation_example(root, info): + try: + sub_id, q = pubsub.subscribe_to_channel('BASE') + while True: + payload = await q.get() + yield payload + finally: + pubsub.unsubscribe('BASE', sub_id) async def resolve_count_seconds(root, info, up_to=5): for i in range(up_to): @@ -31,4 +63,5 @@ async def resolve_random_int(root, info): i += 1 -schema = graphene.Schema(query=Query, subscription=Subscription) +schema = graphene.Schema(query=Query, mutation=Mutations, + subscription=Subscription) diff --git a/examples/flask_gevent/app.py b/examples/flask_gevent/app.py index dbb0cca..4d822a4 100644 --- a/examples/flask_gevent/app.py +++ b/examples/flask_gevent/app.py @@ -20,7 +20,8 @@ def graphql_view(): app.add_url_rule( - '/graphql', view_func=GraphQLView.as_view('graphql', schema=schema, graphiql=False)) + '/graphql', view_func=GraphQLView.as_view('graphql', schema=schema, + graphiql=False)) subscription_server = GeventSubscriptionServer(schema) app.app_protocol = lambda environ_path_info: 'graphql-ws' diff --git a/examples/flask_gevent/schema.py b/examples/flask_gevent/schema.py index 6e6298c..669d270 100644 --- a/examples/flask_gevent/schema.py +++ b/examples/flask_gevent/schema.py @@ -1,11 +1,33 @@ -import random import graphene +import random + +from graphql_ws.pubsub import GeventRxPubsub from rx import Observable +pubsub = GeventRxPubsub() + class Query(graphene.ObjectType): base = graphene.String() + def resolve_base(root, info): + return 'Hello World!' + + +class MutationExample(graphene.Mutation): + class Arguments: + input_text = graphene.String() + + output_text = graphene.String() + + def mutate(self, info, input_text): + pubsub.publish('BASE', input_text) + return MutationExample(output_text=input_text) + + +class Mutations(graphene.ObjectType): + mutation_example = MutationExample.Field() + class RandomType(graphene.ObjectType): seconds = graphene.Int() @@ -13,10 +35,14 @@ class RandomType(graphene.ObjectType): class Subscription(graphene.ObjectType): - count_seconds = graphene.Int(up_to=graphene.Int()) - random_int = graphene.Field(RandomType) + mutation_example = graphene.String() + + def resolve_mutation_example(root, info): + # subscribe_to_channel method returns an observable + return pubsub.subscribe_to_channel('BASE')\ + .map(lambda i: "{0}".format(i)) def resolve_count_seconds(root, info, up_to=5): return Observable.interval(1000)\ @@ -24,7 +50,9 @@ def resolve_count_seconds(root, info, up_to=5): .take_while(lambda i: int(i) <= up_to) def resolve_random_int(root, info): - return Observable.interval(1000).map(lambda i: RandomType(seconds=i, random_int=random.randint(0, 500))) + return Observable.interval(1000).map( + lambda i: RandomType(seconds=i, random_int=random.randint(0, 500))) -schema = graphene.Schema(query=Query, subscription=Subscription) +schema = graphene.Schema(query=Query, mutation=Mutations, + subscription=Subscription) diff --git a/examples/flask_gevent/template.py b/examples/flask_gevent/template.py index e7e0d6a..ef905a8 100644 --- a/examples/flask_gevent/template.py +++ b/examples/flask_gevent/template.py @@ -117,9 +117,8 @@ def render_graphiql(): </script> </body> </html>''').substitute( - GRAPHIQL_VERSION='0.11.7', + GRAPHIQL_VERSION='0.10.2', SUBSCRIPTIONS_TRANSPORT_VERSION='0.7.0', subscriptionsEndpoint='ws://localhost:5000/subscriptions', - # subscriptionsEndpoint='ws://localhost:5000/', endpointURL='/graphql', ) diff --git a/examples/websockets_lib/app.py b/examples/websockets_lib/app.py index 0de6988..33b333c 100644 --- a/examples/websockets_lib/app.py +++ b/examples/websockets_lib/app.py @@ -1,3 +1,5 @@ +import asyncio + from graphql_ws.websockets_lib import WsLibSubscriptionServer from graphql.execution.executors.asyncio import AsyncioExecutor from sanic import Sanic, response @@ -9,12 +11,25 @@ @app.listener('before_server_start') -def init_graphql(app, loop): +async def init_graphql(app, loop): app.add_route(GraphQLView.as_view(schema=schema, executor=AsyncioExecutor(loop=loop)), '/graphql') +@app.listener('before_server_stop') +async def cleanup_subscription_tasks(app, loop): + # clean up tasks created by subscriptions and pubsub + def shutdown_exception_handler(loop, context): + if "exception" not in context or not isinstance( + context["exception"], asyncio.CancelledError): + loop.default_exception_handler(context) + loop.set_exception_handler(shutdown_exception_handler) + pending = asyncio.Task.all_tasks(loop=loop) + future = asyncio.gather(*pending, loop=loop, return_exceptions=True) + future.cancel() + + @app.route('/graphiql') async def graphiql_view(request): return response.html(render_graphiql()) diff --git a/examples/websockets_lib/schema.py b/examples/websockets_lib/schema.py index 3c23d00..1a5f591 100644 --- a/examples/websockets_lib/schema.py +++ b/examples/websockets_lib/schema.py @@ -2,10 +2,32 @@ import asyncio import graphene +from graphql_ws.pubsub import AsyncioPubsub + +pubsub = AsyncioPubsub() + class Query(graphene.ObjectType): base = graphene.String() + async def resolve_base(root, info): + return 'Hello World!' + + +class MutationExample(graphene.Mutation): + class Arguments: + input_text = graphene.String() + + output_text = graphene.String() + + async def mutate(root, info, input_text): + await pubsub.publish('BASE', input_text) + return MutationExample(output_text=input_text) + + +class Mutations(graphene.ObjectType): + mutation_example = MutationExample.Field() + class RandomType(graphene.ObjectType): seconds = graphene.Int() @@ -15,6 +37,16 @@ class RandomType(graphene.ObjectType): class Subscription(graphene.ObjectType): count_seconds = graphene.Float(up_to=graphene.Int()) random_int = graphene.Field(RandomType) + mutation_example = graphene.String() + + async def resolve_mutation_example(root, info): + try: + sub_id, q = pubsub.subscribe_to_channel('BASE') + while True: + payload = await q.get() + yield payload + finally: + pubsub.unsubscribe('BASE', sub_id) async def resolve_count_seconds(root, info, up_to=5): for i in range(up_to): @@ -31,4 +63,5 @@ async def resolve_random_int(root, info): i += 1 -schema = graphene.Schema(query=Query, subscription=Subscription) +schema = graphene.Schema(query=Query, mutation=Mutations, + subscription=Subscription) diff --git a/graphql_ws/aiohttp.py b/graphql_ws/aiohttp.py index 4af5720..e17bf74 100644 --- a/graphql_ws/aiohttp.py +++ b/graphql_ws/aiohttp.py @@ -1,4 +1,4 @@ -from inspect import isawaitable, isasyncgen +from inspect import isawaitable from asyncio import ensure_future from aiohttp import WSMsgType diff --git a/graphql_ws/gevent.py b/graphql_ws/gevent.py index 92a65ce..a261d57 100644 --- a/graphql_ws/gevent.py +++ b/graphql_ws/gevent.py @@ -81,13 +81,14 @@ def on_start(self, connection_context, op_id, params): connection_context.request_context, params) assert isinstance( execution_result, Observable), "A subscription must return an observable" - execution_result.subscribe(SubscriptionObserver( + disposable = execution_result.subscribe(SubscriptionObserver( connection_context, op_id, self.send_execution_result, self.send_error, self.on_close )) + connection_context.register_operation(op_id, disposable) except Exception as e: self.send_error(connection_context, op_id, str(e)) @@ -97,7 +98,8 @@ def on_stop(self, connection_context, op_id): class SubscriptionObserver(Observer): - def __init__(self, connection_context, op_id, send_execution_result, send_error, on_close): + def __init__(self, connection_context, op_id, send_execution_result, + send_error, on_close): self.connection_context = connection_context self.op_id = op_id self.send_execution_result = send_execution_result diff --git a/graphql_ws/pubsub/__init__.py b/graphql_ws/pubsub/__init__.py new file mode 100644 index 0000000..8ec73ba --- /dev/null +++ b/graphql_ws/pubsub/__init__.py @@ -0,0 +1,6 @@ +import sys + +from .gevent_observable import GeventRxPubsub, GeventRxRedisPubsub + +if sys.version_info[0] > 2: + from .asyncio import AsyncioPubsub, AsyncioRedisPubsub diff --git a/graphql_ws/pubsub/asyncio.py b/graphql_ws/pubsub/asyncio.py new file mode 100644 index 0000000..408580b --- /dev/null +++ b/graphql_ws/pubsub/asyncio.py @@ -0,0 +1,76 @@ +import asyncio +import pickle + +import aredis + + +class AsyncioPubsub: + + def __init__(self): + self.subscriptions = {} + self.sub_id = 0 + + async def publish(self, channel, payload): + if channel in self.subscriptions: + for q in self.subscriptions[channel].values(): + await q.put(payload) + + def subscribe_to_channel(self, channel): + self.sub_id += 1 + q = asyncio.Queue() + if channel in self.subscriptions: + self.subscriptions[channel][self.sub_id] = q + else: + self.subscriptions[channel] = {self.sub_id: q} + return self.sub_id, q + + def unsubscribe(self, channel, sub_id): + if sub_id in self.subscriptions.get(channel, {}): + del self.subscriptions[channel][sub_id] + if not self.subscriptions[channel]: + del self.subscriptions[channel] + + +class AsyncioRedisPubsub: + + def __init__(self, host='localhost', port=6379, *args, **kwargs): + self.redis = aredis.StrictRedis(host, port, *args, **kwargs) + self.pubsub = self.redis.pubsub(ignore_subscribe_messages=True) + self.subscriptions = {} + self.sub_id = 0 + self.task = None + + async def publish(self, channel, payload): + await self.redis.publish(channel, pickle.dumps(payload)) + + async def subscribe_to_channel(self, channel): + self.sub_id += 1 + q = asyncio.Queue() + if channel in self.subscriptions: + self.subscriptions[channel][self.sub_id] = q + else: + await self.pubsub.subscribe(channel) + self.subscriptions[channel] = {self.sub_id: q} + if not self.task: + self.task = asyncio.ensure_future( + self._wait_and_get_messages()) + return self.sub_id, q + + async def unsubscribe(self, channel, sub_id): + if sub_id in self.subscriptions.get(channel, {}): + del self.subscriptions[channel][sub_id] + if not self.subscriptions[channel]: + await self.pubsub.unsubscribe(channel) + del self.subscriptions[channel] + if not self.subscriptions: + self.task.cancel() + + async def _wait_and_get_messages(self): + while True: + msg = await self.pubsub.get_message() + if msg: + channel = msg['channel'].decode() + if channel in self.subscriptions: + for q in self.subscriptions[channel].values(): + await q.put(pickle.loads(msg['data'])) + await asyncio.sleep(.001) diff --git a/graphql_ws/pubsub/gevent_observable.py b/graphql_ws/pubsub/gevent_observable.py new file mode 100644 index 0000000..543bc28 --- /dev/null +++ b/graphql_ws/pubsub/gevent_observable.py @@ -0,0 +1,100 @@ +import pickle + +import gevent +import redis + +from rx.subjects import Subject +from rx import config + + +class SubjectObserversWrapper(object): + def __init__(self, pubsub, channel): + self.pubsub = pubsub + self.channel = channel + self.observers = [] + + self.lock = config["concurrency"].RLock() + + def __getitem__(self, key): + return self.observers[key] + + def __getattr__(self, attr): + return getattr(self.observers, attr) + + def remove(self, observer): + with self.lock: + self.observers.remove(observer) + if not self.observers: + self.pubsub.unsubscribe(self.channel) + + +class GeventRxPubsub(object): + + def __init__(self): + self.subscriptions = {} + + def publish(self, channel, payload): + if channel in self.subscriptions: + self.subscriptions[channel].on_next(payload) + + def subscribe_to_channel(self, channel): + if channel in self.subscriptions: + return self.subscriptions[channel] + else: + subject = Subject() + # monkeypatch Subject to unsubscribe pubsub on observable + # subscription.dispose() + subject.observers = SubjectObserversWrapper(self, channel) + self.subscriptions[channel] = subject + return subject + + def unsubscribe(self, channel): + if channel in self.subscriptions: + del self.subscriptions[channel] + + +class GeventRxRedisPubsub(object): + + def __init__(self, host='localhost', port=6379, *args, **kwargs): + redis.connection.socket = gevent.socket + self.redis = redis.StrictRedis(host, port, *args, **kwargs) + self.pubsub = self.redis.pubsub(ignore_subscribe_messages=True) + self.subscriptions = {} + self.greenlet = None + + def publish(self, channel, payload): + self.redis.publish(channel, pickle.dumps(payload)) + + def subscribe_to_channel(self, channel): + if channel in self.subscriptions: + return self.subscriptions[channel] + else: + self.pubsub.subscribe(channel) + subject = Subject() + # monkeypatch Subject to unsubscribe pubsub on observable + # subscription.dispose() + subject.observers = SubjectObserversWrapper(self, channel) + self.subscriptions[channel] = subject + if not self.greenlet: + self.greenlet = gevent.spawn(self._wait_and_get_messages) + return subject + + def unsubscribe(self, channel): + if channel in self.subscriptions: + self.pubsub.unsubscribe(channel) + del self.subscriptions[channel] + if not self.subscriptions: + self.greenlet.kill() + + def _wait_and_get_messages(self): + while True: + msg = self.pubsub.get_message() + if msg: + if isinstance(msg['channel'], bytes): + channel = msg['channel'].decode() + else: + channel = msg['channel'] + if channel in self.subscriptions: + self.subscriptions[channel].on_next(pickle.loads( + msg['data'])) + gevent.sleep(.001) diff --git a/graphql_ws/websockets_lib.py b/graphql_ws/websockets_lib.py index f41a1bb..20fc03b 100644 --- a/graphql_ws/websockets_lib.py +++ b/graphql_ws/websockets_lib.py @@ -1,4 +1,4 @@ -from inspect import isawaitable, isasyncgen +from inspect import isawaitable from asyncio import ensure_future from websockets import ConnectionClosed diff --git a/setup.py b/setup.py index 99844fc..67bbdc9 100644 --- a/setup.py +++ b/setup.py @@ -3,6 +3,8 @@ """The setup script.""" +import sys + from setuptools import setup, find_packages with open('README.rst') as readme_file: @@ -15,9 +17,15 @@ requirements = [ 'graphql-core>=2.0<3', + 'gevent', + 'redis', + 'rx' # TODO: put package requirements here ] +if sys.version_info[0] > 2: + requirements.append('aredis') + setup_requirements = [ 'pytest-runner', # TODO(graphql-python): put setup requirements (distutils extensions,