Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for non-integer PKs #9

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,6 @@ dist/
__pycache__/
.tox
test-results/

# Test DB
db.sqlite3
21 changes: 15 additions & 6 deletions rest_live/consumers.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class Subscription:
# to keep track of all the instances that a given subscription currently considers
# visible. This set keeps track of that. This will probably be the main resource bottleneck
# in django-rest-live
pks_in_queryset: Set[int]
pks_in_queryset: Set[str]


class SubscriptionConsumer(JsonWebsocketConsumer):
Expand All @@ -41,6 +41,7 @@ class SubscriptionConsumer(JsonWebsocketConsumer):
"""

registry: Dict[str, Type[RealtimeMixin]] = dict()
subscriptions: List[Subscription]
public = True

def connect(self):
Expand Down Expand Up @@ -127,7 +128,8 @@ def receive_json(self, content: Dict[str, Any], **kwargs):
if view.action == "retrieve":
view.kwargs.setdefault(view.lookup_field, lookup_value)
try:
view.get_object()
instance = view.get_object()
pks_in_queryset = {str(instance.pk)}
except Http404:
self.send_error(
request_id,
Expand All @@ -137,6 +139,9 @@ def receive_json(self, content: Dict[str, Any], **kwargs):
return
except (NotAuthenticated, PermissionDenied):
has_permission = False
else:
qs = view.filter_queryset(view.get_queryset())
pks_in_queryset = {str(pk) for pk in qs.values_list("pk", flat=True)}

if not has_permission:
self.send_error(
Expand All @@ -156,9 +161,7 @@ def receive_json(self, content: Dict[str, Any], **kwargs):
action=view_action,
view_kwargs=view_kwargs,
query_params=query_params,
pks_in_queryset=set(
[inst["pk"] for inst in view.get_queryset().all().values("pk")]
),
pks_in_queryset=pks_in_queryset
)
)

Expand Down Expand Up @@ -206,12 +209,18 @@ def receive_json(self, content: Dict[str, Any], **kwargs):

def model_saved(self, event):
channel_name: str = event["channel_name"]
instance_pk: int = event["instance_pk"]
instance_pk: str = event["instance_pk"]
model_label: str = event["model"]

viewset_class = self.registry[model_label]

for subscription in self.subscriptions[channel_name]:
if (
subscription.action == "retrieve"
and instance_pk not in subscription.pks_in_queryset
):
continue

view = viewset_class.from_scope(
subscription.action,
self.scope,
Expand Down
2 changes: 1 addition & 1 deletion rest_live/signals.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def save_handler(sender, instance, *args, **kwargs):
{
"type": "model.saved",
"model": model_label,
"instance_pk": instance.pk,
"instance_pk": str(instance.pk),
"channel_name": group_name,
},
)
21 changes: 21 additions & 0 deletions test_app/migrations/0002_uuidtodo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Generated by Django 3.1.7 on 2021-02-24 19:32

from django.db import migrations, models
import uuid


class Migration(migrations.Migration):

dependencies = [
('test_app', '0001_initial'),
]

operations = [
migrations.CreateModel(
name='UUIDTodo',
fields=[
('id', models.UUIDField(default=uuid.uuid4, primary_key=True, serialize=False)),
('included', models.BooleanField(default=True)),
],
),
]
6 changes: 6 additions & 0 deletions test_app/models.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from django.contrib import admin
from django.db import models
from uuid import uuid4


class List(models.Model):
Expand All @@ -13,5 +14,10 @@ class Todo(models.Model):
another_field = models.BooleanField(default=True)


class UUIDTodo(models.Model):
id = models.UUIDField(primary_key=True, default=uuid4)
included = models.BooleanField(default=True)


admin.site.register(List)
admin.site.register(Todo)
9 changes: 7 additions & 2 deletions test_app/serializers.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from rest_framework import serializers
from rest_framework.permissions import BasePermission

from test_app.models import Todo
from test_app.models import Todo, UUIDTodo


class TodoSerializer(serializers.ModelSerializer):
Expand Down Expand Up @@ -30,3 +29,9 @@ class Meta:

def get_auth(self, obj):
return "ADMIN"


class UUIDTodoSerializer(serializers.ModelSerializer):
class Meta:
model = UUIDTodo
fields = ["id"]
8 changes: 7 additions & 1 deletion test_app/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,13 @@
)

from rest_live.mixins import RealtimeMixin
from test_app.models import Todo
from test_app.models import Todo, UUIDTodo

from test_app.serializers import (
TodoSerializer,
AuthedTodoSerializer,
KwargsTodoSerializer,
UUIDTodoSerializer,
)


Expand Down Expand Up @@ -57,3 +58,8 @@ class KwargViewSet(GenericAPIView, RealtimeMixin):
class FilteredViewSet(GenericAPIView, RealtimeMixin):
queryset = Todo.objects.filter(text="special")
serializer_class = TodoSerializer


class UUIDTodoViewSet(GenericAPIView, RealtimeMixin):
queryset = UUIDTodo.objects.filter(included=True)
serializer_class = UUIDTodoSerializer
120 changes: 106 additions & 14 deletions tests/test_live.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,15 @@
from rest_live.routers import RealtimeRouter
from rest_live.testing import async_test, get_headers_for_user

from test_app.models import List, Todo
from test_app.models import List, Todo, UUIDTodo
from test_app.serializers import AuthedTodoSerializer, TodoSerializer
from test_app.views import (
TodoViewSet,
AuthedTodoViewSet,
ConditionalTodoViewSet,
KwargViewSet,
FilteredViewSet,
UUIDTodoViewSet,
)
from tests.utils import RestLiveTestCase

Expand All @@ -45,12 +46,19 @@ async def asyncTearDown(self):

@async_test
async def test_single_update(self):
self.todo = await db(Todo.objects.create)(list=self.list, text="test")
self.todo = await self.make_todo("test")
other_todo = await self.make_todo("another")
req = await self.subscribe_to_todo()

self.todo.text = "MODIFIED"
await db(self.todo.save)()
await self.assertReceivedBroadcastForTodo(self.todo, UPDATED, req)

# Modifying other instances shouldn't trigger a message
other_todo.text = "MODIFIED TOO"
await db(other_todo.save)()
assert await self.client.receive_nothing()

@async_test
async def test_list_unsubscribe(self):
self.todo = await self.make_todo()
Expand Down Expand Up @@ -311,18 +319,6 @@ async def test_list(self):
await self.make_todo("no match")
self.assertTrue(await self.client.receive_nothing())

@async_test
async def test_retrieve(self):
self.todo = await self.make_todo("hello world")
request_id = await self.subscribe_to_todo(params={"search": "hello"})

await db(self.todo.save)()
await self.assertReceivedBroadcastForTodo(self.todo, UPDATED, request_id)

self.todo.text = "goodbye world" # No longer matches the query
await db(self.todo.save)()
await self.assertReceivedBroadcastForTodo(self.todo, DELETED, request_id)


class QuerysetFetchTest(RestLiveTestCase):
"""
Expand Down Expand Up @@ -570,3 +566,99 @@ async def test_subscribe_to_unknown_model(self):
{"type": "subscribe", "id": 1337, "model": "blah.Model", "value": 1}
)
await self.assertReceiveError(1337, 404)


class UUIDTodoTests(RestLiveTestCase):
"""
Tests compatibility with models that use non-integer IDs
"""

async def asyncSetUp(self):
router = RealtimeRouter()
router.register(UUIDTodoViewSet)
self.client = APICommunicator(router.as_consumer(), "/ws/subscribe/")
connected, _ = await self.client.connect()
self.assertTrue(connected)

async def asyncTearDown(self):
await self.client.disconnect()

@async_test
async def test_list(self):
label = UUIDTodo._meta.label
request_id = await self.subscribe(label, "list")
self.assertTrue(await self.client.receive_nothing())

todo = await db(UUIDTodo.objects.create)()
response = await self.client.receive_json_from()
self.assertDictEqual(
{
"type": "broadcast",
"id": request_id,
"model": label,
"action": CREATED,
"instance": {"id": str(todo.id)},
},
response,
)

await db(todo.save)()
response = await self.client.receive_json_from()
self.assertDictEqual(
{
"type": "broadcast",
"id": request_id,
"model": label,
"action": UPDATED,
"instance": {"id": str(todo.id)},
},
response,
)

todo.included = False
await db(todo.save)()
response = await self.client.receive_json_from()
self.assertDictEqual(
{
"type": "broadcast",
"id": request_id,
"model": label,
"action": DELETED,
"instance": {"id": str(todo.id), "pk": str(todo.id)},
},
response,
)

@async_test
async def test_retrieve(self):
label = UUIDTodo._meta.label
todo = await db(UUIDTodo.objects.create)()
request_id = await self.subscribe(label, "retrieve", str(todo.id))
self.assertTrue(await self.client.receive_nothing())

await db(todo.save)()
response = await self.client.receive_json_from()
self.assertDictEqual(
{
"type": "broadcast",
"id": request_id,
"model": label,
"action": UPDATED,
"instance": {"id": str(todo.id)},
},
response,
)

todo.included = False
await db(todo.save)()
response = await self.client.receive_json_from()
self.assertDictEqual(
{
"type": "broadcast",
"id": request_id,
"model": label,
"action": DELETED,
"instance": {"id": str(todo.id), "pk": str(todo.id)},
},
response,
)