1
1
import asyncio
2
2
import typing as t
3
3
4
- import aioredis
4
+ from redis import asyncio as aioredis
5
+ from redis .exceptions import ConnectionError
5
6
from asgiref .sync import sync_to_async
6
7
from django .db .models import QuerySet
7
8
from fastapi import APIRouter , Depends , WebSocket , WebSocketDisconnect , status
@@ -51,7 +52,7 @@ async def get_ticket(
51
52
uid = nacl .encoding .URLSafeBase64Encoder .encode (nacl .utils .random (32 ))
52
53
ticket_model = TicketInner (user = user .id , req = ticket_request )
53
54
ticket_raw = msgpack_encode (ticket_model .dict ())
54
- await redisw .redis .set (uid , ticket_raw , expire = TICKET_VALIDITY_SECONDS * 1000 )
55
+ await redisw .redis .set (uid , ticket_raw , ex = TICKET_VALIDITY_SECONDS * 1000 )
55
56
return TicketOut (ticket = uid )
56
57
57
58
@@ -103,9 +104,9 @@ async def send_item_updates(
103
104
104
105
async def redis_connector (websocket : WebSocket , ticket_model : TicketInner , user : UserType , stoken : t .Optional [str ]):
105
106
async def producer_handler (r : aioredis .Redis , ws : WebSocket ):
107
+ pubsub = r .pubsub ()
106
108
channel_name = f"col.{ ticket_model .req .collection } "
107
- (channel ,) = await r .psubscribe (channel_name )
108
- assert isinstance (channel , aioredis .Channel )
109
+ await pubsub .subscribe (channel_name )
109
110
110
111
# Send missing items if we are not up to date
111
112
queryset : QuerySet [models .Collection ] = get_collection_queryset (user )
@@ -117,12 +118,20 @@ async def producer_handler(r: aioredis.Redis, ws: WebSocket):
117
118
return
118
119
await send_item_updates (websocket , collection , user , stoken )
119
120
121
+ async def handle_message ():
122
+ msg = await pubsub .get_message (ignore_subscribe_messages = True , timeout = 20 )
123
+ message_raw = t .cast (t .Optional [t .Tuple [str , bytes ]], msg )
124
+ if message_raw :
125
+ _ , message = message_raw
126
+ await ws .send_bytes (message )
127
+
120
128
try :
121
129
while True :
122
130
# We wait on the websocket so we fail if web sockets fail or get data
123
131
receive = asyncio .create_task (websocket .receive ())
124
132
done , pending = await asyncio .wait (
125
- {receive , channel .wait_message ()}, return_when = asyncio .FIRST_COMPLETED
133
+ {receive , handle_message ()},
134
+ return_when = asyncio .FIRST_COMPLETED ,
126
135
)
127
136
for task in pending :
128
137
task .cancel ()
@@ -131,12 +140,7 @@ async def producer_handler(r: aioredis.Redis, ws: WebSocket):
131
140
await websocket .close (code = status .WS_1008_POLICY_VIOLATION )
132
141
return
133
142
134
- message_raw = t .cast (t .Optional [t .Tuple [str , bytes ]], await channel .get ())
135
- if message_raw :
136
- _ , message = message_raw
137
- await ws .send_bytes (message )
138
-
139
- except aioredis .errors .ConnectionClosedError :
143
+ except ConnectionError :
140
144
await websocket .close (code = status .WS_1012_SERVICE_RESTART )
141
145
except WebSocketDisconnect :
142
146
pass
0 commit comments