@@ -538,6 +538,13 @@ async def group_send(self, group, message):
538
538
"""
539
539
Sends a message to the entire group.
540
540
"""
541
+ await self .group_send_multiple (group , (message ,))
542
+
543
+ async def group_send_multiple (self , group , messages ):
544
+ """
545
+ Sends a message to the entire group.
546
+ """
547
+
541
548
assert self .valid_group_name (group ), "Group name not valid"
542
549
# Retrieve list of all channel names
543
550
key = self ._group_key (group )
@@ -553,7 +560,7 @@ async def group_send(self, group, message):
553
560
connection_to_channel_keys ,
554
561
channel_keys_to_message ,
555
562
channel_keys_to_capacity ,
556
- ) = self ._map_channel_keys_to_connection (channel_names , message )
563
+ ) = self ._map_channel_keys_to_connection (channel_names , messages )
557
564
558
565
for connection_index , channel_redis_keys in connection_to_channel_keys .items ():
559
566
# Discard old messages based on expiry
@@ -565,17 +572,23 @@ async def group_send(self, group, message):
565
572
await pipe .execute ()
566
573
567
574
# Create a LUA script specific for this connection.
568
- # Make sure to use the message specific to this channel, it is
569
- # stored in channel_to_message dict and contains the
575
+ # Make sure to use the message list specific to this channel, it is
576
+ # stored in channel_to_message dict and each message contains the
570
577
# __asgi_channel__ key.
571
578
572
579
group_send_lua = """
573
580
local over_capacity = 0
581
+ local num_messages = tonumber(ARGV[#ARGV - 2])
574
582
local current_time = ARGV[#ARGV - 1]
575
583
local expiry = ARGV[#ARGV]
576
584
for i=1,#KEYS do
577
- if redis.call('ZCOUNT', KEYS[i], '-inf', '+inf') < tonumber(ARGV[i + #KEYS]) then
578
- redis.call('ZADD', KEYS[i], current_time, ARGV[i])
585
+ if redis.call('ZCOUNT', KEYS[i], '-inf', '+inf') < tonumber(ARGV[i + #KEYS * num_messages]) then
586
+ local messages = {}
587
+ for j=num_messages * (i - 1) + 1, num_messages * i do
588
+ table.insert(messages, current_time)
589
+ table.insert(messages, ARGV[j])
590
+ end
591
+ redis.call('ZADD', KEYS[i], unpack(messages))
579
592
redis.call('EXPIRE', KEYS[i], expiry)
580
593
else
581
594
over_capacity = over_capacity + 1
@@ -585,18 +598,18 @@ async def group_send(self, group, message):
585
598
"""
586
599
587
600
# We need to filter the messages to keep those related to the connection
588
- args = [
589
- channel_keys_to_message [ channel_key ]
590
- for channel_key in channel_redis_keys
591
- ]
601
+ args = []
602
+
603
+ for channel_key in channel_redis_keys :
604
+ args += channel_keys_to_message [ channel_key ]
592
605
593
606
# We need to send the capacity for each channel
594
607
args += [
595
608
channel_keys_to_capacity [channel_key ]
596
609
for channel_key in channel_redis_keys
597
610
]
598
611
599
- args += [time .time (), self .expiry ]
612
+ args += [len ( messages ), time .time (), self .expiry ]
600
613
601
614
# channel_keys does not contain a single redis key more than once
602
615
connection = self .connection (connection_index )
@@ -611,7 +624,7 @@ async def group_send(self, group, message):
611
624
group ,
612
625
)
613
626
614
- def _map_channel_keys_to_connection (self , channel_names , message ):
627
+ def _map_channel_keys_to_connection (self , channel_names , messages ):
615
628
"""
616
629
For a list of channel names, GET
617
630
@@ -626,7 +639,7 @@ def _map_channel_keys_to_connection(self, channel_names, message):
626
639
# Connection dict keyed by index to list of redis keys mapped on that index
627
640
connection_to_channel_keys = collections .defaultdict (list )
628
641
# Message dict maps redis key to the message that needs to be send on that key
629
- channel_key_to_message = dict ( )
642
+ channel_key_to_message = collections . defaultdict ( list )
630
643
# Channel key mapped to its capacity
631
644
channel_key_to_capacity = dict ()
632
645
@@ -640,20 +653,23 @@ def _map_channel_keys_to_connection(self, channel_names, message):
640
653
# Have we come across the same redis key?
641
654
if channel_key not in channel_key_to_message :
642
655
# If not, fill the corresponding dicts
643
- message = dict (message .items ())
644
- message ["__asgi_channel__" ] = [channel ]
645
- channel_key_to_message [channel_key ] = message
656
+ for message in messages :
657
+ message = dict (message .items ())
658
+ message ["__asgi_channel__" ] = [channel ]
659
+ channel_key_to_message [channel_key ].append (message )
646
660
channel_key_to_capacity [channel_key ] = self .get_capacity (channel )
647
661
idx = self .consistent_hash (channel_non_local_name )
648
662
connection_to_channel_keys [idx ].append (channel_key )
649
663
else :
650
664
# Yes, Append the channel in message dict
651
- channel_key_to_message [channel_key ]["__asgi_channel__" ].append (channel )
665
+ for message in channel_key_to_message [channel_key ]:
666
+ message ["__asgi_channel__" ].append (channel )
652
667
653
668
# Now that we know what message needs to be send on a redis key we serialize it
654
669
for key , value in channel_key_to_message .items ():
655
670
# Serialize the message stored for each redis key
656
- channel_key_to_message [key ] = self .serialize (value )
671
+ for idx , message in enumerate (value ):
672
+ channel_key_to_message [key ][idx ] = self .serialize (message )
657
673
658
674
return (
659
675
connection_to_channel_keys ,
0 commit comments