Skip to content

Commit 5fa875f

Browse files
committed
Add group_send_multiple method to core layer
1 parent 4cb9b90 commit 5fa875f

File tree

2 files changed

+68
-17
lines changed

2 files changed

+68
-17
lines changed

channels_redis/core.py

+33-17
Original file line numberDiff line numberDiff line change
@@ -538,6 +538,13 @@ async def group_send(self, group, message):
538538
"""
539539
Sends a message to the entire group.
540540
"""
541+
await self.group_send_multiple(group, (message,))
542+
543+
async def group_send_multiple(self, group, messages):
544+
"""
545+
Sends a message to the entire group.
546+
"""
547+
541548
assert self.valid_group_name(group), "Group name not valid"
542549
# Retrieve list of all channel names
543550
key = self._group_key(group)
@@ -553,7 +560,7 @@ async def group_send(self, group, message):
553560
connection_to_channel_keys,
554561
channel_keys_to_message,
555562
channel_keys_to_capacity,
556-
) = self._map_channel_keys_to_connection(channel_names, message)
563+
) = self._map_channel_keys_to_connection(channel_names, messages)
557564

558565
for connection_index, channel_redis_keys in connection_to_channel_keys.items():
559566
# Discard old messages based on expiry
@@ -565,17 +572,23 @@ async def group_send(self, group, message):
565572
await pipe.execute()
566573

567574
# Create a LUA script specific for this connection.
568-
# Make sure to use the message specific to this channel, it is
569-
# stored in channel_to_message dict and contains the
575+
# Make sure to use the message list specific to this channel, it is
576+
# stored in channel_to_message dict and each message contains the
570577
# __asgi_channel__ key.
571578

572579
group_send_lua = """
573580
local over_capacity = 0
581+
local num_messages = tonumber(ARGV[#ARGV - 2])
574582
local current_time = ARGV[#ARGV - 1]
575583
local expiry = ARGV[#ARGV]
576584
for i=1,#KEYS do
577-
if redis.call('ZCOUNT', KEYS[i], '-inf', '+inf') < tonumber(ARGV[i + #KEYS]) then
578-
redis.call('ZADD', KEYS[i], current_time, ARGV[i])
585+
if redis.call('ZCOUNT', KEYS[i], '-inf', '+inf') < tonumber(ARGV[i + #KEYS * num_messages]) then
586+
local messages = {}
587+
for j=num_messages * (i - 1) + 1, num_messages * i do
588+
table.insert(messages, current_time)
589+
table.insert(messages, ARGV[j])
590+
end
591+
redis.call('ZADD', KEYS[i], unpack(messages))
579592
redis.call('EXPIRE', KEYS[i], expiry)
580593
else
581594
over_capacity = over_capacity + 1
@@ -585,18 +598,18 @@ async def group_send(self, group, message):
585598
"""
586599

587600
# We need to filter the messages to keep those related to the connection
588-
args = [
589-
channel_keys_to_message[channel_key]
590-
for channel_key in channel_redis_keys
591-
]
601+
args = []
602+
603+
for channel_key in channel_redis_keys:
604+
args += channel_keys_to_message[channel_key]
592605

593606
# We need to send the capacity for each channel
594607
args += [
595608
channel_keys_to_capacity[channel_key]
596609
for channel_key in channel_redis_keys
597610
]
598611

599-
args += [time.time(), self.expiry]
612+
args += [len(messages), time.time(), self.expiry]
600613

601614
# channel_keys does not contain a single redis key more than once
602615
connection = self.connection(connection_index)
@@ -611,7 +624,7 @@ async def group_send(self, group, message):
611624
group,
612625
)
613626

614-
def _map_channel_keys_to_connection(self, channel_names, message):
627+
def _map_channel_keys_to_connection(self, channel_names, messages):
615628
"""
616629
For a list of channel names, GET
617630
@@ -626,7 +639,7 @@ def _map_channel_keys_to_connection(self, channel_names, message):
626639
# Connection dict keyed by index to list of redis keys mapped on that index
627640
connection_to_channel_keys = collections.defaultdict(list)
628641
# Message dict maps redis key to the message that needs to be send on that key
629-
channel_key_to_message = dict()
642+
channel_key_to_message = collections.defaultdict(list)
630643
# Channel key mapped to its capacity
631644
channel_key_to_capacity = dict()
632645

@@ -640,20 +653,23 @@ def _map_channel_keys_to_connection(self, channel_names, message):
640653
# Have we come across the same redis key?
641654
if channel_key not in channel_key_to_message:
642655
# If not, fill the corresponding dicts
643-
message = dict(message.items())
644-
message["__asgi_channel__"] = [channel]
645-
channel_key_to_message[channel_key] = message
656+
for message in messages:
657+
message = dict(message.items())
658+
message["__asgi_channel__"] = [channel]
659+
channel_key_to_message[channel_key].append(message)
646660
channel_key_to_capacity[channel_key] = self.get_capacity(channel)
647661
idx = self.consistent_hash(channel_non_local_name)
648662
connection_to_channel_keys[idx].append(channel_key)
649663
else:
650664
# Yes, Append the channel in message dict
651-
channel_key_to_message[channel_key]["__asgi_channel__"].append(channel)
665+
for message in channel_key_to_message[channel_key]:
666+
message["__asgi_channel__"].append(channel)
652667

653668
# Now that we know what message needs to be send on a redis key we serialize it
654669
for key, value in channel_key_to_message.items():
655670
# Serialize the message stored for each redis key
656-
channel_key_to_message[key] = self.serialize(value)
671+
for idx, message in enumerate(value):
672+
channel_key_to_message[key][idx] = self.serialize(message)
657673

658674
return (
659675
connection_to_channel_keys,

tests/test_core.py

+35
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import asyncio
2+
import collections
23
import random
34

45
import async_timeout
@@ -244,6 +245,40 @@ async def test_groups_basic(channel_layer):
244245
await channel_layer.flush()
245246

246247

248+
@pytest.mark.asyncio
249+
async def test_groups_multiple(channel_layer):
250+
"""
251+
Tests basic group operation.
252+
"""
253+
channel_name1 = await channel_layer.new_channel(prefix="test-gr-chan-1")
254+
channel_name2 = await channel_layer.new_channel(prefix="test-gr-chan-2")
255+
channel_name3 = await channel_layer.new_channel(prefix="test-gr-chan-3")
256+
await channel_layer.group_add("test-group", channel_name1)
257+
await channel_layer.group_add("test-group", channel_name2)
258+
await channel_layer.group_add("test-group", channel_name3)
259+
260+
messages = [
261+
{"type": "message.1"},
262+
{"type": "message.2"},
263+
{"type": "message.3"},
264+
]
265+
266+
expected = {msg["type"] for msg in messages}
267+
268+
await channel_layer.group_send_multiple("test-group", messages)
269+
270+
received = collections.defaultdict(set)
271+
272+
for channel_name in (channel_name1, channel_name2, channel_name3):
273+
async with async_timeout.timeout(1):
274+
for _ in range(len(messages)):
275+
received[channel_name].add(
276+
(await channel_layer.receive(channel_name))["type"]
277+
)
278+
279+
assert received[channel_name] == expected
280+
281+
247282
@pytest.mark.asyncio
248283
async def test_groups_channel_full(channel_layer):
249284
"""

0 commit comments

Comments
 (0)