diff --git a/.gitignore b/.gitignore index 6bbdc4f..94df2ad 100644 --- a/.gitignore +++ b/.gitignore @@ -12,3 +12,6 @@ dist/ __pycache__/ .tox test-results/ + +# Test DB +db.sqlite3 diff --git a/rest_live/consumers.py b/rest_live/consumers.py index e325529..670be87 100644 --- a/rest_live/consumers.py +++ b/rest_live/consumers.py @@ -28,7 +28,7 @@ class Subscription: # to keep track of all the instances that a given subscription currently considers # visible. This set keeps track of that. This will probably be the main resource bottleneck # in django-rest-live - pks_in_queryset: Set[int] + pks_in_queryset: Set[str] class SubscriptionConsumer(JsonWebsocketConsumer): @@ -41,6 +41,7 @@ class SubscriptionConsumer(JsonWebsocketConsumer): """ registry: Dict[str, Type[RealtimeMixin]] = dict() + subscriptions: List[Subscription] public = True def connect(self): @@ -127,7 +128,8 @@ def receive_json(self, content: Dict[str, Any], **kwargs): if view.action == "retrieve": view.kwargs.setdefault(view.lookup_field, lookup_value) try: - view.get_object() + instance = view.get_object() + pks_in_queryset = {str(instance.pk)} except Http404: self.send_error( request_id, @@ -137,6 +139,9 @@ def receive_json(self, content: Dict[str, Any], **kwargs): return except (NotAuthenticated, PermissionDenied): has_permission = False + else: + qs = view.filter_queryset(view.get_queryset()) + pks_in_queryset = {str(pk) for pk in qs.values_list("pk", flat=True)} if not has_permission: self.send_error( @@ -156,9 +161,7 @@ def receive_json(self, content: Dict[str, Any], **kwargs): action=view_action, view_kwargs=view_kwargs, query_params=query_params, - pks_in_queryset=set( - [inst["pk"] for inst in view.get_queryset().all().values("pk")] - ), + pks_in_queryset=pks_in_queryset ) ) @@ -206,12 +209,18 @@ def receive_json(self, content: Dict[str, Any], **kwargs): def model_saved(self, event): channel_name: str = event["channel_name"] - instance_pk: int = event["instance_pk"] + instance_pk: str = event["instance_pk"] model_label: str = event["model"] viewset_class = self.registry[model_label] for subscription in self.subscriptions[channel_name]: + if ( + subscription.action == "retrieve" + and instance_pk not in subscription.pks_in_queryset + ): + continue + view = viewset_class.from_scope( subscription.action, self.scope, diff --git a/rest_live/signals.py b/rest_live/signals.py index 693800b..824b107 100644 --- a/rest_live/signals.py +++ b/rest_live/signals.py @@ -13,7 +13,7 @@ def save_handler(sender, instance, *args, **kwargs): { "type": "model.saved", "model": model_label, - "instance_pk": instance.pk, + "instance_pk": str(instance.pk), "channel_name": group_name, }, ) diff --git a/test_app/migrations/0002_uuidtodo.py b/test_app/migrations/0002_uuidtodo.py new file mode 100644 index 0000000..a8112fb --- /dev/null +++ b/test_app/migrations/0002_uuidtodo.py @@ -0,0 +1,21 @@ +# Generated by Django 3.1.7 on 2021-02-24 19:32 + +from django.db import migrations, models +import uuid + + +class Migration(migrations.Migration): + + dependencies = [ + ('test_app', '0001_initial'), + ] + + operations = [ + migrations.CreateModel( + name='UUIDTodo', + fields=[ + ('id', models.UUIDField(default=uuid.uuid4, primary_key=True, serialize=False)), + ('included', models.BooleanField(default=True)), + ], + ), + ] diff --git a/test_app/models.py b/test_app/models.py index f39266d..bef62ef 100644 --- a/test_app/models.py +++ b/test_app/models.py @@ -1,5 +1,6 @@ from django.contrib import admin from django.db import models +from uuid import uuid4 class List(models.Model): @@ -13,5 +14,10 @@ class Todo(models.Model): another_field = models.BooleanField(default=True) +class UUIDTodo(models.Model): + id = models.UUIDField(primary_key=True, default=uuid4) + included = models.BooleanField(default=True) + + admin.site.register(List) admin.site.register(Todo) diff --git a/test_app/serializers.py b/test_app/serializers.py index 2288cd0..5489c3b 100644 --- a/test_app/serializers.py +++ b/test_app/serializers.py @@ -1,7 +1,6 @@ from rest_framework import serializers -from rest_framework.permissions import BasePermission -from test_app.models import Todo +from test_app.models import Todo, UUIDTodo class TodoSerializer(serializers.ModelSerializer): @@ -30,3 +29,9 @@ class Meta: def get_auth(self, obj): return "ADMIN" + + +class UUIDTodoSerializer(serializers.ModelSerializer): + class Meta: + model = UUIDTodo + fields = ["id"] diff --git a/test_app/views.py b/test_app/views.py index c762e97..5c727e6 100644 --- a/test_app/views.py +++ b/test_app/views.py @@ -6,12 +6,13 @@ ) from rest_live.mixins import RealtimeMixin -from test_app.models import Todo +from test_app.models import Todo, UUIDTodo from test_app.serializers import ( TodoSerializer, AuthedTodoSerializer, KwargsTodoSerializer, + UUIDTodoSerializer, ) @@ -57,3 +58,8 @@ class KwargViewSet(GenericAPIView, RealtimeMixin): class FilteredViewSet(GenericAPIView, RealtimeMixin): queryset = Todo.objects.filter(text="special") serializer_class = TodoSerializer + + +class UUIDTodoViewSet(GenericAPIView, RealtimeMixin): + queryset = UUIDTodo.objects.filter(included=True) + serializer_class = UUIDTodoSerializer diff --git a/tests/test_live.py b/tests/test_live.py index 7aa574a..90e54f2 100644 --- a/tests/test_live.py +++ b/tests/test_live.py @@ -11,7 +11,7 @@ from rest_live.routers import RealtimeRouter from rest_live.testing import async_test, get_headers_for_user -from test_app.models import List, Todo +from test_app.models import List, Todo, UUIDTodo from test_app.serializers import AuthedTodoSerializer, TodoSerializer from test_app.views import ( TodoViewSet, @@ -19,6 +19,7 @@ ConditionalTodoViewSet, KwargViewSet, FilteredViewSet, + UUIDTodoViewSet, ) from tests.utils import RestLiveTestCase @@ -45,12 +46,19 @@ async def asyncTearDown(self): @async_test async def test_single_update(self): - self.todo = await db(Todo.objects.create)(list=self.list, text="test") + self.todo = await self.make_todo("test") + other_todo = await self.make_todo("another") req = await self.subscribe_to_todo() + self.todo.text = "MODIFIED" await db(self.todo.save)() await self.assertReceivedBroadcastForTodo(self.todo, UPDATED, req) + # Modifying other instances shouldn't trigger a message + other_todo.text = "MODIFIED TOO" + await db(other_todo.save)() + assert await self.client.receive_nothing() + @async_test async def test_list_unsubscribe(self): self.todo = await self.make_todo() @@ -311,18 +319,6 @@ async def test_list(self): await self.make_todo("no match") self.assertTrue(await self.client.receive_nothing()) - @async_test - async def test_retrieve(self): - self.todo = await self.make_todo("hello world") - request_id = await self.subscribe_to_todo(params={"search": "hello"}) - - await db(self.todo.save)() - await self.assertReceivedBroadcastForTodo(self.todo, UPDATED, request_id) - - self.todo.text = "goodbye world" # No longer matches the query - await db(self.todo.save)() - await self.assertReceivedBroadcastForTodo(self.todo, DELETED, request_id) - class QuerysetFetchTest(RestLiveTestCase): """ @@ -570,3 +566,99 @@ async def test_subscribe_to_unknown_model(self): {"type": "subscribe", "id": 1337, "model": "blah.Model", "value": 1} ) await self.assertReceiveError(1337, 404) + + +class UUIDTodoTests(RestLiveTestCase): + """ + Tests compatibility with models that use non-integer IDs + """ + + async def asyncSetUp(self): + router = RealtimeRouter() + router.register(UUIDTodoViewSet) + self.client = APICommunicator(router.as_consumer(), "/ws/subscribe/") + connected, _ = await self.client.connect() + self.assertTrue(connected) + + async def asyncTearDown(self): + await self.client.disconnect() + + @async_test + async def test_list(self): + label = UUIDTodo._meta.label + request_id = await self.subscribe(label, "list") + self.assertTrue(await self.client.receive_nothing()) + + todo = await db(UUIDTodo.objects.create)() + response = await self.client.receive_json_from() + self.assertDictEqual( + { + "type": "broadcast", + "id": request_id, + "model": label, + "action": CREATED, + "instance": {"id": str(todo.id)}, + }, + response, + ) + + await db(todo.save)() + response = await self.client.receive_json_from() + self.assertDictEqual( + { + "type": "broadcast", + "id": request_id, + "model": label, + "action": UPDATED, + "instance": {"id": str(todo.id)}, + }, + response, + ) + + todo.included = False + await db(todo.save)() + response = await self.client.receive_json_from() + self.assertDictEqual( + { + "type": "broadcast", + "id": request_id, + "model": label, + "action": DELETED, + "instance": {"id": str(todo.id), "pk": str(todo.id)}, + }, + response, + ) + + @async_test + async def test_retrieve(self): + label = UUIDTodo._meta.label + todo = await db(UUIDTodo.objects.create)() + request_id = await self.subscribe(label, "retrieve", str(todo.id)) + self.assertTrue(await self.client.receive_nothing()) + + await db(todo.save)() + response = await self.client.receive_json_from() + self.assertDictEqual( + { + "type": "broadcast", + "id": request_id, + "model": label, + "action": UPDATED, + "instance": {"id": str(todo.id)}, + }, + response, + ) + + todo.included = False + await db(todo.save)() + response = await self.client.receive_json_from() + self.assertDictEqual( + { + "type": "broadcast", + "id": request_id, + "model": label, + "action": DELETED, + "instance": {"id": str(todo.id), "pk": str(todo.id)}, + }, + response, + )