From a33a38c6a8b1062dc9d4bf37e517b118f11b2ea7 Mon Sep 17 00:00:00 2001 From: Ed Rivas Date: Wed, 24 Feb 2021 13:46:36 -0600 Subject: [PATCH 1/2] Add support for non-integer PKs --- .gitignore | 3 + rest_live/consumers.py | 8 ++- rest_live/signals.py | 2 +- test_app/migrations/0002_uuidtodo.py | 21 ++++++ test_app/models.py | 6 ++ test_app/serializers.py | 9 ++- test_app/views.py | 9 ++- tests/test_live.py | 99 +++++++++++++++++++++++++++- 8 files changed, 148 insertions(+), 9 deletions(-) create mode 100644 test_app/migrations/0002_uuidtodo.py diff --git a/.gitignore b/.gitignore index 6bbdc4f..94df2ad 100644 --- a/.gitignore +++ b/.gitignore @@ -12,3 +12,6 @@ dist/ __pycache__/ .tox test-results/ + +# Test DB +db.sqlite3 diff --git a/rest_live/consumers.py b/rest_live/consumers.py index 9e371e4..780b3be 100644 --- a/rest_live/consumers.py +++ b/rest_live/consumers.py @@ -26,7 +26,7 @@ class Subscription: # to keep track of all the instances that a given subscription currently considers # visible. This set keeps track of that. This will probably be the main resource bottleneck # in django-rest-live - pks_in_queryset: Set[int] + pks_in_queryset: Set[str] class SubscriptionConsumer(JsonWebsocketConsumer): @@ -159,7 +159,9 @@ def receive_json(self, content: Dict[str, Any], **kwargs): view_kwargs=view_kwargs, query_params=query_params, pks_in_queryset=set( - [inst["pk"] for inst in view.get_queryset().all().values("pk")] + str(pk) for pk in view.filter_queryset( + view.get_queryset() + ).values_list("pk", flat=True) ), ) ) @@ -208,7 +210,7 @@ def receive_json(self, content: Dict[str, Any], **kwargs): def model_saved(self, event): channel_name: str = event["channel_name"] - instance_pk: int = event["instance_pk"] + instance_pk: str = event["instance_pk"] model_label: str = event["model"] viewset_class = self.registry[model_label] diff --git a/rest_live/signals.py b/rest_live/signals.py index 693800b..824b107 100644 --- a/rest_live/signals.py +++ b/rest_live/signals.py @@ -13,7 +13,7 @@ def save_handler(sender, instance, *args, **kwargs): { "type": "model.saved", "model": model_label, - "instance_pk": instance.pk, + "instance_pk": str(instance.pk), "channel_name": group_name, }, ) diff --git a/test_app/migrations/0002_uuidtodo.py b/test_app/migrations/0002_uuidtodo.py new file mode 100644 index 0000000..a8112fb --- /dev/null +++ b/test_app/migrations/0002_uuidtodo.py @@ -0,0 +1,21 @@ +# Generated by Django 3.1.7 on 2021-02-24 19:32 + +from django.db import migrations, models +import uuid + + +class Migration(migrations.Migration): + + dependencies = [ + ('test_app', '0001_initial'), + ] + + operations = [ + migrations.CreateModel( + name='UUIDTodo', + fields=[ + ('id', models.UUIDField(default=uuid.uuid4, primary_key=True, serialize=False)), + ('included', models.BooleanField(default=True)), + ], + ), + ] diff --git a/test_app/models.py b/test_app/models.py index f39266d..bef62ef 100644 --- a/test_app/models.py +++ b/test_app/models.py @@ -1,5 +1,6 @@ from django.contrib import admin from django.db import models +from uuid import uuid4 class List(models.Model): @@ -13,5 +14,10 @@ class Todo(models.Model): another_field = models.BooleanField(default=True) +class UUIDTodo(models.Model): + id = models.UUIDField(primary_key=True, default=uuid4) + included = models.BooleanField(default=True) + + admin.site.register(List) admin.site.register(Todo) diff --git a/test_app/serializers.py b/test_app/serializers.py index 2288cd0..5489c3b 100644 --- a/test_app/serializers.py +++ b/test_app/serializers.py @@ -1,7 +1,6 @@ from rest_framework import serializers -from rest_framework.permissions import BasePermission -from test_app.models import Todo +from test_app.models import Todo, UUIDTodo class TodoSerializer(serializers.ModelSerializer): @@ -30,3 +29,9 @@ class Meta: def get_auth(self, obj): return "ADMIN" + + +class UUIDTodoSerializer(serializers.ModelSerializer): + class Meta: + model = UUIDTodo + fields = ["id"] diff --git a/test_app/views.py b/test_app/views.py index 6942f68..05e6acb 100644 --- a/test_app/views.py +++ b/test_app/views.py @@ -2,17 +2,17 @@ from rest_framework.generics import GenericAPIView from rest_framework.permissions import ( IsAuthenticated, - DjangoModelPermissions, BasePermission, ) from rest_live.mixins import RealtimeMixin -from test_app.models import Todo +from test_app.models import Todo, UUIDTodo from test_app.serializers import ( TodoSerializer, AuthedTodoSerializer, KwargsTodoSerializer, + UUIDTodoSerializer, ) @@ -56,3 +56,8 @@ class KwargViewSet(GenericAPIView, RealtimeMixin): class FilteredViewSet(GenericAPIView, RealtimeMixin): queryset = Todo.objects.filter(text="special") serializer_class = TodoSerializer + + +class UUIDTodoViewSet(GenericAPIView, RealtimeMixin): + queryset = UUIDTodo.objects.filter(included=True) + serializer_class = UUIDTodoSerializer diff --git a/tests/test_live.py b/tests/test_live.py index 3f07387..8d35a16 100644 --- a/tests/test_live.py +++ b/tests/test_live.py @@ -11,7 +11,7 @@ from rest_live.routers import RealtimeRouter from rest_live.testing import async_test, get_headers_for_user -from test_app.models import List, Todo +from test_app.models import List, Todo, UUIDTodo from test_app.serializers import AuthedTodoSerializer, TodoSerializer from test_app.views import ( TodoViewSet, @@ -19,6 +19,7 @@ ConditionalTodoViewSet, KwargViewSet, FilteredViewSet, + UUIDTodoViewSet, ) from tests.utils import RestLiveTestCase @@ -522,3 +523,99 @@ async def test_subscribe_to_unknown_model(self): {"type": "subscribe", "id": 1337, "model": "blah.Model", "value": 1} ) await self.assertReceiveError(1337, 404) + + +class UUIDTodoTests(RestLiveTestCase): + """ + Tests compatibility with models that use non-integer IDs + """ + + async def asyncSetUp(self): + router = RealtimeRouter() + router.register(UUIDTodoViewSet) + self.client = APICommunicator(router.as_consumer(), "/ws/subscribe/") + connected, _ = await self.client.connect() + self.assertTrue(connected) + + async def asyncTearDown(self): + await self.client.disconnect() + + @async_test + async def test_list(self): + label = UUIDTodo._meta.label + request_id = await self.subscribe(label, "list") + self.assertTrue(await self.client.receive_nothing()) + + todo = await db(UUIDTodo.objects.create)() + response = await self.client.receive_json_from() + self.assertDictEqual( + { + "type": "broadcast", + "id": request_id, + "model": label, + "action": CREATED, + "instance": {"id": str(todo.id)}, + }, + response, + ) + + await db(todo.save)() + response = await self.client.receive_json_from() + self.assertDictEqual( + { + "type": "broadcast", + "id": request_id, + "model": label, + "action": UPDATED, + "instance": {"id": str(todo.id)}, + }, + response, + ) + + todo.included = False + await db(todo.save)() + response = await self.client.receive_json_from() + self.assertDictEqual( + { + "type": "broadcast", + "id": request_id, + "model": label, + "action": DELETED, + "instance": {"id": str(todo.id), "pk": str(todo.id)}, + }, + response, + ) + + @async_test + async def test_retrieve(self): + label = UUIDTodo._meta.label + todo = await db(UUIDTodo.objects.create)() + request_id = await self.subscribe(label, "retrieve", str(todo.id)) + self.assertTrue(await self.client.receive_nothing()) + + await db(todo.save)() + response = await self.client.receive_json_from() + self.assertDictEqual( + { + "type": "broadcast", + "id": request_id, + "model": label, + "action": UPDATED, + "instance": {"id": str(todo.id)}, + }, + response, + ) + + todo.included = False + await db(todo.save)() + response = await self.client.receive_json_from() + self.assertDictEqual( + { + "type": "broadcast", + "id": request_id, + "model": label, + "action": DELETED, + "instance": {"id": str(todo.id), "pk": str(todo.id)}, + }, + response, + ) From 315052026538d2b1b739628bd5a22d9739be5836 Mon Sep 17 00:00:00 2001 From: Ed Rivas Date: Wed, 10 Mar 2021 12:05:09 -0600 Subject: [PATCH 2/2] Limit messages to single instance for "retrieve" subscriptions Previously modifying other instances in the queryset would trigger messages for all "retrieve" subscriptions --- rest_live/consumers.py | 19 +++++++++++++------ tests/test_live.py | 21 ++++++++------------- 2 files changed, 21 insertions(+), 19 deletions(-) diff --git a/rest_live/consumers.py b/rest_live/consumers.py index ef2e6ac..670be87 100644 --- a/rest_live/consumers.py +++ b/rest_live/consumers.py @@ -41,6 +41,7 @@ class SubscriptionConsumer(JsonWebsocketConsumer): """ registry: Dict[str, Type[RealtimeMixin]] = dict() + subscriptions: List[Subscription] public = True def connect(self): @@ -127,7 +128,8 @@ def receive_json(self, content: Dict[str, Any], **kwargs): if view.action == "retrieve": view.kwargs.setdefault(view.lookup_field, lookup_value) try: - view.get_object() + instance = view.get_object() + pks_in_queryset = {str(instance.pk)} except Http404: self.send_error( request_id, @@ -137,6 +139,9 @@ def receive_json(self, content: Dict[str, Any], **kwargs): return except (NotAuthenticated, PermissionDenied): has_permission = False + else: + qs = view.filter_queryset(view.get_queryset()) + pks_in_queryset = {str(pk) for pk in qs.values_list("pk", flat=True)} if not has_permission: self.send_error( @@ -156,11 +161,7 @@ def receive_json(self, content: Dict[str, Any], **kwargs): action=view_action, view_kwargs=view_kwargs, query_params=query_params, - pks_in_queryset=set( - str(pk) for pk in view.filter_queryset( - view.get_queryset() - ).values_list("pk", flat=True) - ), + pks_in_queryset=pks_in_queryset ) ) @@ -214,6 +215,12 @@ def model_saved(self, event): viewset_class = self.registry[model_label] for subscription in self.subscriptions[channel_name]: + if ( + subscription.action == "retrieve" + and instance_pk not in subscription.pks_in_queryset + ): + continue + view = viewset_class.from_scope( subscription.action, self.scope, diff --git a/tests/test_live.py b/tests/test_live.py index 54b58eb..90e54f2 100644 --- a/tests/test_live.py +++ b/tests/test_live.py @@ -46,12 +46,19 @@ async def asyncTearDown(self): @async_test async def test_single_update(self): - self.todo = await db(Todo.objects.create)(list=self.list, text="test") + self.todo = await self.make_todo("test") + other_todo = await self.make_todo("another") req = await self.subscribe_to_todo() + self.todo.text = "MODIFIED" await db(self.todo.save)() await self.assertReceivedBroadcastForTodo(self.todo, UPDATED, req) + # Modifying other instances shouldn't trigger a message + other_todo.text = "MODIFIED TOO" + await db(other_todo.save)() + assert await self.client.receive_nothing() + @async_test async def test_list_unsubscribe(self): self.todo = await self.make_todo() @@ -312,18 +319,6 @@ async def test_list(self): await self.make_todo("no match") self.assertTrue(await self.client.receive_nothing()) - @async_test - async def test_retrieve(self): - self.todo = await self.make_todo("hello world") - request_id = await self.subscribe_to_todo(params={"search": "hello"}) - - await db(self.todo.save)() - await self.assertReceivedBroadcastForTodo(self.todo, UPDATED, request_id) - - self.todo.text = "goodbye world" # No longer matches the query - await db(self.todo.save)() - await self.assertReceivedBroadcastForTodo(self.todo, DELETED, request_id) - class QuerysetFetchTest(RestLiveTestCase): """