Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Subscribe to all primaries (for key-space-notifications) #2860

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions src/StackExchange.Redis/Interfaces/ISubscriber.cs
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,39 @@ public interface ISubscriber : IRedis
/// <inheritdoc cref="Subscribe(RedisChannel, Action{RedisChannel, RedisValue}, CommandFlags)"/>
Task SubscribeAsync(RedisChannel channel, Action<RedisChannel, RedisValue> handler, CommandFlags flags = CommandFlags.None);

/// <summary>
/// Subscribe to perform some operation when a message to any primary node is broadcast, without any guarantee of ordered handling.
/// This is most useful when addressing key space notifications. For user controlled pub/sub channels, it should be used with caution and further it is not advised.
/// </summary>
/// <param name="channel">The channel to subscribe to.</param>
/// <param name="handler">The handler to invoke when a message is received on <paramref name="channel"/>.</param>
/// <param name="flags">The command flags to use.</param>
/// <remarks>
/// See
/// <seealso href="https://redis.io/docs/latest/develop/use/keyspace-notifications/"/>,
/// <seealso href="https://redis.io/commands/subscribe"/>,
/// <seealso href="https://redis.io/commands/psubscribe"/>.
/// </remarks>
Task SubscribeAllPrimariesAsync(RedisChannel channel, Action<RedisChannel, RedisValue> handler, CommandFlags flags = CommandFlags.None);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if we should infer "all primaries" (vs secondaries) from flags, like we would with regular subscribe, so this becomes SubscribeAll; I also wonder whether we should internalize this entirely, specifically against the keyspace events, because for regular keys:

  • non-'s' variants will get broadcast duplicates that will be hard to untangle
  • s variants don't make sense to attach to the entire cluster, since they are routed using shard rules

That would also have the advantage of not needing a separate API. I'm not certain on this - mostly thinking aloud here.


/// <summary>
/// Unsubscribe from a specified message channel on all primary nodes.
/// Note: if no handler is specified, the subscription is canceled regardless of the subscribers.
/// If a handler is specified, the subscription is only canceled if this handler is the last handler remaining against the channel.
/// This is used in combination with <see cref="SubscribeAllPrimariesAsync"/> and as mentioned there,
/// it is intended to use with key space notitications and not advised for user controlled pub/sub channels.
/// </summary>
/// <param name="channel">The channel that was subscribed to.</param>
/// <param name="handler">The handler to no longer invoke when a message is received on <paramref name="channel"/>.</param>
/// <param name="flags">The command flags to use.</param>
/// <remarks>
/// See
/// <seealso href="https://redis.io/docs/latest/develop/use/keyspace-notifications/"/>,
/// <seealso href="https://redis.io/commands/unsubscribe"/>,
/// <seealso href="https://redis.io/commands/punsubscribe"/>.
/// </remarks>
Task UnsubscribeAllPrimariesAsync(RedisChannel channel, Action<RedisChannel, RedisValue>? handler = null, CommandFlags flags = CommandFlags.None);

/// <summary>
/// Subscribe to perform some operation when a message to the preferred/active node is broadcast, as a queue that guarantees ordered handling.
/// </summary>
Expand Down
3 changes: 2 additions & 1 deletion src/StackExchange.Redis/PublicAPI/PublicAPI.Shipped.txt
Original file line number Diff line number Diff line change
Expand Up @@ -1893,4 +1893,5 @@ virtual StackExchange.Redis.RedisResult.Length.get -> int
virtual StackExchange.Redis.RedisResult.this[int index].get -> StackExchange.Redis.RedisResult!
StackExchange.Redis.ConnectionMultiplexer.AddLibraryNameSuffix(string! suffix) -> void
StackExchange.Redis.IConnectionMultiplexer.AddLibraryNameSuffix(string! suffix) -> void

StackExchange.Redis.ISubscriber.SubscribeAllPrimariesAsync(StackExchange.Redis.RedisChannel channel, System.Action<StackExchange.Redis.RedisChannel, StackExchange.Redis.RedisValue>! handler, StackExchange.Redis.CommandFlags flags = StackExchange.Redis.CommandFlags.None) -> System.Threading.Tasks.Task!
StackExchange.Redis.ISubscriber.UnsubscribeAllPrimariesAsync(StackExchange.Redis.RedisChannel channel, System.Action<StackExchange.Redis.RedisChannel, StackExchange.Redis.RedisValue>? handler = null, StackExchange.Redis.CommandFlags flags = StackExchange.Redis.CommandFlags.None) -> System.Threading.Tasks.Task!
82 changes: 82 additions & 0 deletions src/StackExchange.Redis/RedisSubscriber.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Net;
using System.Threading;
using System.Threading.Tasks;
Expand Down Expand Up @@ -432,6 +434,86 @@ public Task<bool> SubscribeAsync(RedisChannel channel, Action<RedisChannel, Redi
return EnsureSubscribedToServerAsync(sub, channel, flags, false);
}

Task ISubscriber.SubscribeAllPrimariesAsync(RedisChannel channel, Action<RedisChannel, RedisValue> handler, CommandFlags flags)
=> SubscribeAllPrimariesAsync(channel, handler, null, flags);

public Task<bool> SubscribeAllPrimariesAsync(RedisChannel channel, Action<RedisChannel, RedisValue>? handler, ChannelMessageQueue? queue, CommandFlags flags)
{
ThrowIfNull(channel);
if (handler == null && queue == null) { return CompletedTask<bool>.Default(null); }

var sub = multiplexer.GetOrAddSubscription(channel, flags);
sub.Add(handler, queue);
return EnsureSubscribedToPrimariesAsync(sub, channel, flags, false);
}

private Task<bool> EnsureSubscribedToPrimariesAsync(Subscription sub, RedisChannel channel, CommandFlags flags, bool internalCall)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

there's almost certainly something complex in here around channel reconnects... I'll need to dig

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oww, now i see..
we are missing the cases with failovers and etc..
i ll try to explore the cases. thanks.

{
if (sub.IsConnected) { return CompletedTask<bool>.Default(null); }

// TODO: Cleanup old hangers here?
sub.SetCurrentServer(null); // we're not appropriately connected, so blank it out for eligible reconnection
var tasks = new List<Task<bool>>();
foreach (var server in multiplexer.GetServerSnapshot())
{
if (!server.IsReplica)
{
var message = sub.GetMessage(channel, SubscriptionAction.Subscribe, flags, internalCall);
tasks.Add(ExecuteAsync(message, sub.Processor, server));
}
}

if (tasks.Count == 0)
{
return CompletedTask<bool>.Default(false);
}

// Create a new task that will collect all results and observe errors
return Task.Run(async () =>
{
// Wait for all tasks to complete
var results = await Task.WhenAll(tasks).ObserveErrors();
return results.All(result => result);
}).ObserveErrors();
}

Task ISubscriber.UnsubscribeAllPrimariesAsync(RedisChannel channel, Action<RedisChannel, RedisValue>? handler, CommandFlags flags)
=> UnsubscribeAllPrimariesAsync(channel, handler, null, flags);

public Task<bool> UnsubscribeAllPrimariesAsync(in RedisChannel channel, Action<RedisChannel, RedisValue>? handler, ChannelMessageQueue? queue, CommandFlags flags)
{
ThrowIfNull(channel);
// Unregister the subscription handler/queue, and if that returns true (last handler removed), also disconnect from the server
return UnregisterSubscription(channel, handler, queue, out var sub)
? UnsubscribeFromPrimariesAsync(sub, channel, flags, asyncState, false)
: CompletedTask<bool>.Default(asyncState);
}

private Task<bool> UnsubscribeFromPrimariesAsync(Subscription sub, RedisChannel channel, CommandFlags flags, object? asyncState, bool internalCall)
{
var tasks = new List<Task<bool>>();
foreach (var server in multiplexer.GetServerSnapshot())
{
if (!server.IsReplica)
{
var message = sub.GetMessage(channel, SubscriptionAction.Unsubscribe, flags, internalCall);
tasks.Add(multiplexer.ExecuteAsyncImpl(message, sub.Processor, asyncState, server));
}
}

if (tasks.Count == 0)
{
return CompletedTask<bool>.Default(false);
}

// Create a new task that will collect all results and observe errors
return Task.Run(async () =>
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suspect we can unwrap a little of the extra task layer here, but I'm not too concerned - this is low throughput

{
// Wait for all tasks to complete
var results = await Task.WhenAll(tasks).ObserveErrors();
return results.All(result => result);
}).ObserveErrors();
}
public Task<bool> EnsureSubscribedToServerAsync(Subscription sub, RedisChannel channel, CommandFlags flags, bool internalCall)
{
if (sub.IsConnected) { return CompletedTask<bool>.Default(null); }
Expand Down
81 changes: 81 additions & 0 deletions tests/StackExchange.Redis.Tests/ClusterTests.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.IO;
using System.Linq;
Expand Down Expand Up @@ -746,4 +747,84 @@ public void ConnectIncludesSubscriber()
Assert.Equal(PhysicalBridge.State.ConnectedEstablished, server.SubscriptionConnectionState);
}
}

[Fact]
public async void SubscribeAllPrimariesAsync()
{
var receivedNofitications = new ConcurrentDictionary<string, int>();
var expectedNofitications = new HashSet<string>();
Action<RedisChannel, RedisValue> Accept = (channel, message) => receivedNofitications[$"{message} on channel {channel}"] = 0;

var redis = Create(keepAlive: 1, connectTimeout: 3000, shared: false, allowAdmin: true);
var db = redis.GetDatabase();
SwitchKeySpaceNofitications(redis);

var sub = redis.GetSubscriber();
var channel = new RedisChannel("__key*@*__:*", RedisChannel.PatternMode.Pattern);
await sub.SubscribeAllPrimariesAsync(channel, Accept);

for (int i = 1; i < 3; i++)
{
await db.StringSetAsync($"k{i}", i);
expectedNofitications.Add($"set on channel __keyspace@0__:k{i}");
expectedNofitications.Add($"k{i} on channel __keyevent@0__:set");
}

// Wait for notifications to be processed
await Task.Delay(1000).ForAwait();
foreach (var notification in expectedNofitications)
{
Assert.Contains(notification, receivedNofitications);
}
SwitchKeySpaceNofitications(redis, false);

// Assert that all expected notifications are contained in received notifications
Assert.True(
expectedNofitications.IsSubsetOf(receivedNofitications.Keys),
$"Expected notifications were not received. Expected: {string.Join(", ", expectedNofitications)}, Received: {string.Join(", ", receivedNofitications)}");
}

[Fact]
public async void UnsubscribeAllPrimariesAsync()
{
var receivedNofitications = new ConcurrentDictionary<string, int>();
Action<RedisChannel, RedisValue> Accept = (channel, message) => receivedNofitications[$"{message} on channel {channel}"] = 0;

var redis = Create(keepAlive: 1, connectTimeout: 3000, shared: false, allowAdmin: true);
var db = redis.GetDatabase();
SwitchKeySpaceNofitications(redis);

var sub = redis.GetSubscriber();
var channel = new RedisChannel("__key*@*__:*", RedisChannel.PatternMode.Pattern);
await sub.SubscribeAllPrimariesAsync(channel, Accept);

for (int i = 1; i < 3; i++)
{
await db.StringSetAsync($"k{i}", i);
}

// Wait for notifications to be processed
await Task.Delay(1000).ForAwait();
Assert.NotEmpty(receivedNofitications);
receivedNofitications.Clear();
await sub.UnsubscribeAllPrimariesAsync(channel, Accept);
for (int i = 1; i < 3; i++)
{
await db.StringSetAsync($"k{i}", i);
}
await Task.Delay(1000).ForAwait();
Assert.Empty(receivedNofitications);

SwitchKeySpaceNofitications(redis, false);
}

private static void SwitchKeySpaceNofitications(IInternalConnectionMultiplexer redis, bool enable = true)
{
string onOff = enable ? "KEA" : "";
foreach (var endpoint in redis.GetEndPoints())
{
var server = redis.GetServer(endpoint);
if (!server.IsReplica) server.ConfigSet("notify-keyspace-events", onOff);
}
}
}
Loading