Unverified Commit 93ee0fb6 authored by Marc Gravell's avatar Marc Gravell Committed by GitHub

Rework queues (#1112)

* rework how pubsub handlers/queues are stores and activated
- switch to linked-list for queues
- use alloc-free path for all handlers invokes

* defer writing to the queues until we've released the subscriber lock
parent 57c75d18
...@@ -86,11 +86,9 @@ internal ChannelMessageQueue(in RedisChannel redisChannel, RedisSubscriber paren ...@@ -86,11 +86,9 @@ internal ChannelMessageQueue(in RedisChannel redisChannel, RedisSubscriber paren
SingleReader = false, SingleReader = false,
AllowSynchronousContinuations = false, AllowSynchronousContinuations = false,
}; };
internal void Subscribe(CommandFlags flags) => _parent.Subscribe(Channel, HandleMessage, flags);
internal Task SubscribeAsync(CommandFlags flags) => _parent.SubscribeAsync(Channel, HandleMessage, flags);
#pragma warning disable RCS1231 // Make parameter ref read-only. - uses as a delegate for Action<RedisChannel, RedisValue> #pragma warning disable RCS1231 // Make parameter ref read-only. - uses as a delegate for Action<RedisChannel, RedisValue>
private void HandleMessage(RedisChannel channel, RedisValue value) private void Write(in RedisChannel channel, in RedisValue value)
#pragma warning restore RCS1231 // Make parameter ref read-only. #pragma warning restore RCS1231 // Make parameter ref read-only.
{ {
var writer = _queue.Writer; var writer = _queue.Writer;
...@@ -170,6 +168,20 @@ private async Task OnMessageSyncImpl() ...@@ -170,6 +168,20 @@ private async Task OnMessageSyncImpl()
} }
} }
internal static void Combine(ref ChannelMessageQueue head, ChannelMessageQueue queue)
{
if (queue != null)
{
// insert at the start of the linked-list
ChannelMessageQueue old;
do
{
old = Volatile.Read(ref head);
queue._next = old;
} while (Interlocked.CompareExchange(ref head, queue, old) != old);
}
}
/// <summary> /// <summary>
/// Create a message loop that processes messages sequentially. /// Create a message loop that processes messages sequentially.
/// </summary> /// </summary>
...@@ -182,6 +194,75 @@ public void OnMessage(Func<ChannelMessage, Task> handler) ...@@ -182,6 +194,75 @@ public void OnMessage(Func<ChannelMessage, Task> handler)
state => ((ChannelMessageQueue)state).OnMessageAsyncImpl().RedisFireAndForget(), this); state => ((ChannelMessageQueue)state).OnMessageAsyncImpl().RedisFireAndForget(), this);
} }
internal static void Remove(ref ChannelMessageQueue head, ChannelMessageQueue queue)
{
if (queue == null) return;
bool found;
do // if we fail due to a conflict, re-do from start
{
var current = Volatile.Read(ref head);
if (current == null) return; // no queue? nothing to do
if (current == queue)
{
found = true;
// found at the head - then we need to change the head
if (Interlocked.CompareExchange(ref head, Volatile.Read(ref current._next), current) == current)
{
return; // success
}
}
else
{
ChannelMessageQueue previous = current;
current = Volatile.Read(ref previous._next);
found = false;
do
{
if (current == queue)
{
found = true;
// found it, not at the head; remove the node
if (Interlocked.CompareExchange(ref previous._next, Volatile.Read(ref current._next), current) == current)
{
return; // success
}
else
{
break; // exit the inner loop, and repeat the outer loop
}
}
previous = current;
current = Volatile.Read(ref previous._next);
} while (current != null);
}
} while (found);
}
internal static int Count(ref ChannelMessageQueue head)
{
var current = Volatile.Read(ref head);
int count = 0;
while (current != null)
{
count++;
current = Volatile.Read(ref current._next);
}
return count;
}
internal static void WriteAll(ref ChannelMessageQueue head, in RedisChannel channel, in RedisValue message)
{
var current = Volatile.Read(ref head);
while (current != null)
{
current.Write(channel, message);
current = Volatile.Read(ref current._next);
}
}
private ChannelMessageQueue _next;
private async Task OnMessageAsyncImpl() private async Task OnMessageAsyncImpl()
{ {
var handler = (Func<ChannelMessage, Task>)_onMessageHandler; var handler = (Func<ChannelMessage, Task>)_onMessageHandler;
...@@ -205,14 +286,13 @@ private async Task OnMessageAsyncImpl() ...@@ -205,14 +286,13 @@ private async Task OnMessageAsyncImpl()
} }
} }
internal static void MarkCompleted(Action<RedisChannel, RedisValue> handler) internal static void MarkAllCompleted(ref ChannelMessageQueue head)
{ {
if (handler != null) var current = Interlocked.Exchange(ref head, null);
while (current != null)
{ {
foreach (Action<RedisChannel, RedisValue> sub in handler.GetInvocationList()) current.MarkCompleted();
{ current = Volatile.Read(ref current._next);
if (sub.Target is ChannelMessageQueue queue) queue.MarkCompleted();
}
} }
} }
...@@ -228,7 +308,7 @@ internal void UnsubscribeImpl(Exception error = null, CommandFlags flags = Comma ...@@ -228,7 +308,7 @@ internal void UnsubscribeImpl(Exception error = null, CommandFlags flags = Comma
_parent = null; _parent = null;
if (parent != null) if (parent != null)
{ {
parent.UnsubscribeAsync(Channel, HandleMessage, flags); parent.UnsubscribeAsync(Channel, null, this, flags);
} }
_queue.Writer.TryComplete(error); _queue.Writer.TryComplete(error);
} }
...@@ -239,24 +319,11 @@ internal async Task UnsubscribeAsyncImpl(Exception error = null, CommandFlags fl ...@@ -239,24 +319,11 @@ internal async Task UnsubscribeAsyncImpl(Exception error = null, CommandFlags fl
_parent = null; _parent = null;
if (parent != null) if (parent != null)
{ {
await parent.UnsubscribeAsync(Channel, HandleMessage, flags).ForAwait(); await parent.UnsubscribeAsync(Channel, null, this, flags).ForAwait();
} }
_queue.Writer.TryComplete(error); _queue.Writer.TryComplete(error);
} }
internal static bool IsOneOf(Action<RedisChannel, RedisValue> handler)
{
try
{
return handler?.Target is ChannelMessageQueue
&& handler.Method.Name == nameof(HandleMessage);
}
catch
{
return false;
}
}
/// <summary> /// <summary>
/// Stop receiving messages on this channel. /// Stop receiving messages on this channel.
/// </summary> /// </summary>
......
using System; using System;
using System.Text; using System.Text;
using Pipelines.Sockets.Unofficial;
namespace StackExchange.Redis namespace StackExchange.Redis
{ {
...@@ -7,15 +8,14 @@ internal sealed class MessageCompletable : ICompletable ...@@ -7,15 +8,14 @@ internal sealed class MessageCompletable : ICompletable
{ {
private readonly RedisChannel channel; private readonly RedisChannel channel;
private readonly Action<RedisChannel, RedisValue> syncHandler, asyncHandler; private readonly Action<RedisChannel, RedisValue> handler;
private readonly RedisValue message; private readonly RedisValue message;
public MessageCompletable(RedisChannel channel, RedisValue message, Action<RedisChannel, RedisValue> syncHandler, Action<RedisChannel, RedisValue> asyncHandler) public MessageCompletable(RedisChannel channel, RedisValue message, Action<RedisChannel, RedisValue> handler)
{ {
this.channel = channel; this.channel = channel;
this.message = message; this.message = message;
this.syncHandler = syncHandler; this.handler = handler;
this.asyncHandler = asyncHandler;
} }
public override string ToString() => (string)channel; public override string ToString() => (string)channel;
...@@ -24,13 +24,19 @@ public bool TryComplete(bool isAsync) ...@@ -24,13 +24,19 @@ public bool TryComplete(bool isAsync)
{ {
if (isAsync) if (isAsync)
{ {
if (asyncHandler != null) if (handler != null)
{ {
ConnectionMultiplexer.TraceWithoutContext("Invoking (async)...: " + (string)channel, "Subscription"); ConnectionMultiplexer.TraceWithoutContext("Invoking (async)...: " + (string)channel, "Subscription");
foreach (Action<RedisChannel, RedisValue> sub in asyncHandler.GetInvocationList()) if (handler.IsSingle())
{ {
try { sub.Invoke(channel, message); } try { handler(channel, message); } catch { }
catch { } }
else
{
foreach (var sub in handler.AsEnumerable())
{
try { sub.Invoke(channel, message); } catch { }
}
} }
ConnectionMultiplexer.TraceWithoutContext("Invoke complete (async)", "Subscription"); ConnectionMultiplexer.TraceWithoutContext("Invoke complete (async)", "Subscription");
} }
...@@ -38,17 +44,7 @@ public bool TryComplete(bool isAsync) ...@@ -38,17 +44,7 @@ public bool TryComplete(bool isAsync)
} }
else else
{ {
if (syncHandler != null) return handler == null; // anything async to do?
{
ConnectionMultiplexer.TraceWithoutContext("Invoking (sync)...: " + (string)channel, "Subscription");
foreach (Action<RedisChannel, RedisValue> sub in syncHandler.GetInvocationList())
{
try { sub.Invoke(channel, message); }
catch { }
}
ConnectionMultiplexer.TraceWithoutContext("Invoke complete (sync)", "Subscription");
}
return asyncHandler == null; // anything async to do?
} }
} }
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
using System.Runtime.CompilerServices; using System.Runtime.CompilerServices;
using System.Threading; using System.Threading;
using System.Threading.Tasks; using System.Threading.Tasks;
using Pipelines.Sockets.Unofficial;
namespace StackExchange.Redis namespace StackExchange.Redis
{ {
...@@ -24,10 +25,16 @@ internal static void CompleteAsWorker(ICompletable completable) ...@@ -24,10 +25,16 @@ internal static void CompleteAsWorker(ICompletable completable)
if (handler == null) return true; if (handler == null) return true;
if (isAsync) if (isAsync)
{ {
foreach (EventHandler<T> sub in handler.GetInvocationList()) if (handler.IsSingle())
{ {
try { sub.Invoke(sender, args); } try { handler(sender, args); } catch { }
catch { } }
else
{
foreach (EventHandler<T> sub in handler.AsEnumerable())
{
try { sub(sender, args); } catch { }
}
} }
return true; return true;
} }
...@@ -37,27 +44,39 @@ internal static void CompleteAsWorker(ICompletable completable) ...@@ -37,27 +44,39 @@ internal static void CompleteAsWorker(ICompletable completable)
} }
} }
internal Task AddSubscription(in RedisChannel channel, Action<RedisChannel, RedisValue> handler, CommandFlags flags, object asyncState) internal bool GetSubscriberCounts(in RedisChannel channel, out int handlers, out int queues)
{
if (handler != null)
{ {
bool asAsync = !ChannelMessageQueue.IsOneOf(handler); Subscription sub;
lock (subscriptions) lock (subscriptions)
{ {
if (subscriptions.TryGetValue(channel, out Subscription sub)) if (!subscriptions.TryGetValue(channel, out sub)) sub = null;
}
if (sub != null)
{ {
sub.Add(asAsync, handler); sub.GetSubscriberCounts(out handlers, out queues);
return true;
} }
else handlers = queues = 0;
return false;
}
internal Task AddSubscription(in RedisChannel channel, Action<RedisChannel, RedisValue> handler, ChannelMessageQueue queue, CommandFlags flags, object asyncState)
{
Task task = null;
if (handler != null | queue != null)
{
lock (subscriptions)
{
if (!subscriptions.TryGetValue(channel, out Subscription sub))
{ {
sub = new Subscription(asAsync, handler); sub = new Subscription();
subscriptions.Add(channel, sub); subscriptions.Add(channel, sub);
var task = sub.SubscribeToServer(this, channel, flags, asyncState, false); task = sub.SubscribeToServer(this, channel, flags, asyncState, false);
if (task != null) return task;
} }
sub.Add(handler, queue);
} }
} }
return CompletedTask<bool>.Default(asyncState); return task ?? CompletedTask<bool>.Default(asyncState);
} }
internal ServerEndPoint GetSubscribedServer(in RedisChannel channel) internal ServerEndPoint GetSubscribedServer(in RedisChannel channel)
...@@ -78,13 +97,16 @@ internal ServerEndPoint GetSubscribedServer(in RedisChannel channel) ...@@ -78,13 +97,16 @@ internal ServerEndPoint GetSubscribedServer(in RedisChannel channel)
internal void OnMessage(in RedisChannel subscription, in RedisChannel channel, in RedisValue payload) internal void OnMessage(in RedisChannel subscription, in RedisChannel channel, in RedisValue payload)
{ {
ICompletable completable = null; ICompletable completable = null;
ChannelMessageQueue queues = null;
Subscription sub;
lock (subscriptions) lock (subscriptions)
{ {
if (subscriptions.TryGetValue(subscription, out Subscription sub)) if (subscriptions.TryGetValue(subscription, out sub))
{ {
completable = sub.ForInvoke(channel, payload); completable = sub.ForInvoke(channel, payload, out queues);
} }
} }
if (queues != null) ChannelMessageQueue.WriteAll(ref queues, channel, payload);
if (completable != null && !completable.TryComplete(false)) ConnectionMultiplexer.CompleteAsWorker(completable); if (completable != null && !completable.TryComplete(false)) ConnectionMultiplexer.CompleteAsWorker(completable);
} }
...@@ -104,7 +126,7 @@ internal Task RemoveAllSubscriptions(CommandFlags flags, object asyncState) ...@@ -104,7 +126,7 @@ internal Task RemoveAllSubscriptions(CommandFlags flags, object asyncState)
return last ?? CompletedTask<bool>.Default(asyncState); return last ?? CompletedTask<bool>.Default(asyncState);
} }
internal Task RemoveSubscription(in RedisChannel channel, Action<RedisChannel, RedisValue> handler, CommandFlags flags, object asyncState) internal Task RemoveSubscription(in RedisChannel channel, Action<RedisChannel, RedisValue> handler, ChannelMessageQueue queue, CommandFlags flags, object asyncState)
{ {
Task task = null; Task task = null;
lock (subscriptions) lock (subscriptions)
...@@ -112,15 +134,14 @@ internal Task RemoveSubscription(in RedisChannel channel, Action<RedisChannel, R ...@@ -112,15 +134,14 @@ internal Task RemoveSubscription(in RedisChannel channel, Action<RedisChannel, R
if (subscriptions.TryGetValue(channel, out Subscription sub)) if (subscriptions.TryGetValue(channel, out Subscription sub))
{ {
bool remove; bool remove;
if (handler == null) // blanket wipe if (handler == null & queue == null) // blanket wipe
{ {
sub.MarkCompleted(); sub.MarkCompleted();
remove = true; remove = true;
} }
else else
{ {
bool asAsync = !ChannelMessageQueue.IsOneOf(handler); remove = sub.Remove(handler, queue);
remove = sub.Remove(asAsync, handler);
} }
if (remove) if (remove)
{ {
...@@ -168,44 +189,34 @@ internal long ValidateSubscriptions() ...@@ -168,44 +189,34 @@ internal long ValidateSubscriptions()
internal sealed class Subscription internal sealed class Subscription
{ {
private Action<RedisChannel, RedisValue> _asyncHandler, _syncHandler; private Action<RedisChannel, RedisValue> _handlers;
private ChannelMessageQueue _queues;
private ServerEndPoint owner; private ServerEndPoint owner;
public Subscription(bool asAsync, Action<RedisChannel, RedisValue> value) public void Add(Action<RedisChannel, RedisValue> handler, ChannelMessageQueue queue)
{ {
if (asAsync) _asyncHandler = value; if (handler != null) _handlers += handler;
else _syncHandler = value; if (queue != null) ChannelMessageQueue.Combine(ref _queues, queue);
} }
public void Add(bool asAsync, Action<RedisChannel, RedisValue> value) public ICompletable ForInvoke(in RedisChannel channel, in RedisValue message, out ChannelMessageQueue queues)
{ {
if (asAsync) _asyncHandler += value; var handlers = _handlers;
else _syncHandler += value; queues = Volatile.Read(ref _queues);
} return handlers == null ? null : new MessageCompletable(channel, message, handlers);
public ICompletable ForInvoke(in RedisChannel channel, in RedisValue message)
{
var syncHandler = _syncHandler;
var asyncHandler = _asyncHandler;
return (syncHandler == null && asyncHandler == null) ? null : new MessageCompletable(channel, message, syncHandler, asyncHandler);
} }
internal void MarkCompleted() internal void MarkCompleted()
{ {
_asyncHandler = null; _handlers = null;
var oldSync = _syncHandler; ChannelMessageQueue.MarkAllCompleted(ref _queues);
_syncHandler = null;
ChannelMessageQueue.MarkCompleted(oldSync);
} }
public bool Remove(bool asAsync, Action<RedisChannel, RedisValue> value) public bool Remove(Action<RedisChannel, RedisValue> handler, ChannelMessageQueue queue)
{
if (value != null)
{ {
if (asAsync) _asyncHandler -= value; if (handler != null) _handlers -= handler;
else _syncHandler -= value; if (queue != null) ChannelMessageQueue.Remove(ref _queues, queue);
} return _handlers == null & _queues == null;
return _syncHandler == null && _asyncHandler == null;
} }
public Task SubscribeToServer(ConnectionMultiplexer multiplexer, in RedisChannel channel, CommandFlags flags, object asyncState, bool internalCall) public Task SubscribeToServer(ConnectionMultiplexer multiplexer, in RedisChannel channel, CommandFlags flags, object asyncState, bool internalCall)
...@@ -316,6 +327,25 @@ internal bool Validate(ConnectionMultiplexer multiplexer, in RedisChannel channe ...@@ -316,6 +327,25 @@ internal bool Validate(ConnectionMultiplexer multiplexer, in RedisChannel channe
} }
return changed; return changed;
} }
internal void GetSubscriberCounts(out int handlers, out int queues)
{
queues = ChannelMessageQueue.Count(ref _queues);
var tmp = _handlers;
if (tmp == null)
{
handlers = 0;
}
else if (tmp.IsSingle())
{
handlers = 1;
}
else
{
handlers = 0;
foreach (var sub in tmp.AsEnumerable()) { handlers++; }
}
}
} }
internal string GetConnectionName(EndPoint endPoint, ConnectionType connectionType) internal string GetConnectionName(EndPoint endPoint, ConnectionType connectionType)
...@@ -437,30 +467,39 @@ public Task<long> PublishAsync(RedisChannel channel, RedisValue message, Command ...@@ -437,30 +467,39 @@ public Task<long> PublishAsync(RedisChannel channel, RedisValue message, Command
return ExecuteAsync(msg, ResultProcessor.Int64); return ExecuteAsync(msg, ResultProcessor.Int64);
} }
public void Subscribe(RedisChannel channel, Action<RedisChannel, RedisValue> handler, CommandFlags flags = CommandFlags.None) void ISubscriber.Subscribe(RedisChannel channel, Action<RedisChannel, RedisValue> handler, CommandFlags flags)
=> Subscribe(channel, handler, null, flags);
public void Subscribe(RedisChannel channel, Action<RedisChannel, RedisValue> handler, ChannelMessageQueue queue, CommandFlags flags)
{ {
var task = SubscribeAsync(channel, handler, flags); var task = SubscribeAsync(channel, handler, queue, flags);
if ((flags & CommandFlags.FireAndForget) == 0) Wait(task); if ((flags & CommandFlags.FireAndForget) == 0) Wait(task);
} }
public ChannelMessageQueue Subscribe(RedisChannel channel, CommandFlags flags = CommandFlags.None) public ChannelMessageQueue Subscribe(RedisChannel channel, CommandFlags flags = CommandFlags.None)
{ {
var c = new ChannelMessageQueue(channel, this); var queue = new ChannelMessageQueue(channel, this);
c.Subscribe(flags); Subscribe(channel, null, queue, flags);
return c; return queue;
} }
public Task SubscribeAsync(RedisChannel channel, Action<RedisChannel, RedisValue> handler, CommandFlags flags = CommandFlags.None) Task ISubscriber.SubscribeAsync(RedisChannel channel, Action<RedisChannel, RedisValue> handler, CommandFlags flags)
=> SubscribeAsync(channel, handler, null, flags);
public Task SubscribeAsync(in RedisChannel channel, Action<RedisChannel, RedisValue> handler, ChannelMessageQueue queue, CommandFlags flags)
{ {
if (channel.IsNullOrEmpty) throw new ArgumentNullException(nameof(channel)); if (channel.IsNullOrEmpty) throw new ArgumentNullException(nameof(channel));
return multiplexer.AddSubscription(channel, handler, flags, asyncState); return multiplexer.AddSubscription(channel, handler, queue, flags, asyncState);
} }
internal bool GetSubscriberCounts(in RedisChannel channel, out int handlers, out int queues)
=> multiplexer.GetSubscriberCounts(channel, out handlers, out queues);
public async Task<ChannelMessageQueue> SubscribeAsync(RedisChannel channel, CommandFlags flags = CommandFlags.None) public async Task<ChannelMessageQueue> SubscribeAsync(RedisChannel channel, CommandFlags flags = CommandFlags.None)
{ {
var c = new ChannelMessageQueue(channel, this); var queue = new ChannelMessageQueue(channel, this);
await c.SubscribeAsync(flags).ForAwait(); await SubscribeAsync(channel, null, queue, flags).ForAwait();
return c; return queue;
} }
public EndPoint SubscribedEndpoint(RedisChannel channel) public EndPoint SubscribedEndpoint(RedisChannel channel)
...@@ -469,9 +508,11 @@ public EndPoint SubscribedEndpoint(RedisChannel channel) ...@@ -469,9 +508,11 @@ public EndPoint SubscribedEndpoint(RedisChannel channel)
return server?.EndPoint; return server?.EndPoint;
} }
public void Unsubscribe(RedisChannel channel, Action<RedisChannel, RedisValue> handler = null, CommandFlags flags = CommandFlags.None) void ISubscriber.Unsubscribe(RedisChannel channel, Action<RedisChannel, RedisValue> handler, CommandFlags flags)
=> Unsubscribe(channel, handler, null, flags);
public void Unsubscribe(in RedisChannel channel, Action<RedisChannel, RedisValue> handler, ChannelMessageQueue queue, CommandFlags flags)
{ {
var task = UnsubscribeAsync(channel, handler, flags); var task = UnsubscribeAsync(channel, handler, queue, flags);
if ((flags & CommandFlags.FireAndForget) == 0) Wait(task); if ((flags & CommandFlags.FireAndForget) == 0) Wait(task);
} }
...@@ -486,10 +527,12 @@ public Task UnsubscribeAllAsync(CommandFlags flags = CommandFlags.None) ...@@ -486,10 +527,12 @@ public Task UnsubscribeAllAsync(CommandFlags flags = CommandFlags.None)
return multiplexer.RemoveAllSubscriptions(flags, asyncState); return multiplexer.RemoveAllSubscriptions(flags, asyncState);
} }
public Task UnsubscribeAsync(RedisChannel channel, Action<RedisChannel, RedisValue> handler = null, CommandFlags flags = CommandFlags.None) Task ISubscriber.UnsubscribeAsync(RedisChannel channel, Action<RedisChannel, RedisValue> handler, CommandFlags flags)
=> UnsubscribeAsync(channel, handler, null, flags);
public Task UnsubscribeAsync(in RedisChannel channel, Action<RedisChannel, RedisValue> handler, ChannelMessageQueue queue, CommandFlags flags)
{ {
if (channel.IsNullOrEmpty) throw new ArgumentNullException(nameof(channel)); if (channel.IsNullOrEmpty) throw new ArgumentNullException(nameof(channel));
return multiplexer.RemoveSubscription(channel, handler, flags, asyncState); return multiplexer.RemoveSubscription(channel, handler, queue, flags, asyncState);
} }
} }
} }
using System.Collections.Generic; using System.Collections.Generic;
using System.Linq; using System.Linq;
using System.Threading;
using System.Threading.Tasks; using System.Threading.Tasks;
using Xunit; using Xunit;
using Xunit.Abstractions; using Xunit.Abstractions;
...@@ -10,6 +11,14 @@ public class Issue1101 : TestBase ...@@ -10,6 +11,14 @@ public class Issue1101 : TestBase
{ {
public Issue1101(ITestOutputHelper output) : base(output) { } public Issue1101(ITestOutputHelper output) : base(output) { }
static void AssertCounts(ISubscriber pubsub, in RedisChannel channel,
bool has, int handlers, int queues)
{
var aHas = ((RedisSubscriber)pubsub).GetSubscriberCounts(channel, out var ah, out var aq);
Assert.Equal(has, aHas);
Assert.Equal(handlers, ah);
Assert.Equal(queues, aq);
}
[Fact] [Fact]
public async Task ExecuteWithUnsubscribeViaChannel() public async Task ExecuteWithUnsubscribeViaChannel()
{ {
...@@ -17,15 +26,20 @@ public async Task ExecuteWithUnsubscribeViaChannel() ...@@ -17,15 +26,20 @@ public async Task ExecuteWithUnsubscribeViaChannel()
{ {
RedisChannel name = Me(); RedisChannel name = Me();
var pubsub = muxer.GetSubscriber(); var pubsub = muxer.GetSubscriber();
AssertCounts(pubsub, name, false, 0, 0);
// subscribe and check we get data // subscribe and check we get data
var channel = await pubsub.SubscribeAsync(name); var first = await pubsub.SubscribeAsync(name);
var second = await pubsub.SubscribeAsync(name);
AssertCounts(pubsub, name, true, 0, 2);
List<string> values = new List<string>(); List<string> values = new List<string>();
channel.OnMessage(x => int i = 0;
first.OnMessage(x =>
{ {
lock (values) { values.Add(x.Message); } lock (values) { values.Add(x.Message); }
return Task.CompletedTask; return Task.CompletedTask;
}); });
second.OnMessage(_ => Interlocked.Increment(ref i));
await Task.Delay(100); await Task.Delay(100);
await pubsub.PublishAsync(name, "abc"); await pubsub.PublishAsync(name, "abc");
await Task.Delay(100); await Task.Delay(100);
...@@ -35,9 +49,10 @@ public async Task ExecuteWithUnsubscribeViaChannel() ...@@ -35,9 +49,10 @@ public async Task ExecuteWithUnsubscribeViaChannel()
} }
var subs = muxer.GetServer(muxer.GetEndPoints().Single()).SubscriptionSubscriberCount(name); var subs = muxer.GetServer(muxer.GetEndPoints().Single()).SubscriptionSubscriberCount(name);
Assert.Equal(1, subs); Assert.Equal(1, subs);
Assert.False(channel.Completion.IsCompleted, "completed"); Assert.False(first.Completion.IsCompleted, "completed");
Assert.False(second.Completion.IsCompleted, "completed");
await channel.UnsubscribeAsync(); await first.UnsubscribeAsync();
await Task.Delay(100); await Task.Delay(100);
await pubsub.PublishAsync(name, "def"); await pubsub.PublishAsync(name, "def");
await Task.Delay(100); await Task.Delay(100);
...@@ -45,10 +60,29 @@ public async Task ExecuteWithUnsubscribeViaChannel() ...@@ -45,10 +60,29 @@ public async Task ExecuteWithUnsubscribeViaChannel()
{ {
Assert.Equal("abc", Assert.Single(values)); Assert.Equal("abc", Assert.Single(values));
} }
Assert.Equal(2, Volatile.Read(ref i));
Assert.True(first.Completion.IsCompleted, "completed");
Assert.False(second.Completion.IsCompleted, "completed");
AssertCounts(pubsub, name, true, 0, 1);
await second.UnsubscribeAsync();
await Task.Delay(100);
await pubsub.PublishAsync(name, "ghi");
await Task.Delay(100);
lock (values)
{
Assert.Equal("abc", Assert.Single(values));
}
Assert.Equal(2, Volatile.Read(ref i));
Assert.True(first.Completion.IsCompleted, "completed");
Assert.True(second.Completion.IsCompleted, "completed");
AssertCounts(pubsub, name, false, 0, 0);
subs = muxer.GetServer(muxer.GetEndPoints().Single()).SubscriptionSubscriberCount(name); subs = muxer.GetServer(muxer.GetEndPoints().Single()).SubscriptionSubscriberCount(name);
Assert.Equal(0, subs); Assert.Equal(0, subs);
Assert.True(channel.Completion.IsCompleted, "completed"); Assert.True(first.Completion.IsCompleted, "completed");
Assert.True(second.Completion.IsCompleted, "completed");
} }
} }
...@@ -59,15 +93,21 @@ public async Task ExecuteWithUnsubscribeViaSubscriber() ...@@ -59,15 +93,21 @@ public async Task ExecuteWithUnsubscribeViaSubscriber()
{ {
RedisChannel name = Me(); RedisChannel name = Me();
var pubsub = muxer.GetSubscriber(); var pubsub = muxer.GetSubscriber();
AssertCounts(pubsub, name, false, 0, 0);
// subscribe and check we get data // subscribe and check we get data
var channel = await pubsub.SubscribeAsync(name); var first = await pubsub.SubscribeAsync(name);
var second = await pubsub.SubscribeAsync(name);
AssertCounts(pubsub, name, true, 0, 2);
List<string> values = new List<string>(); List<string> values = new List<string>();
channel.OnMessage(x => int i = 0;
first.OnMessage(x =>
{ {
lock (values) { values.Add(x.Message); } lock (values) { values.Add(x.Message); }
return Task.CompletedTask; return Task.CompletedTask;
}); });
second.OnMessage(_ => Interlocked.Increment(ref i));
await Task.Delay(100); await Task.Delay(100);
await pubsub.PublishAsync(name, "abc"); await pubsub.PublishAsync(name, "abc");
await Task.Delay(100); await Task.Delay(100);
...@@ -77,7 +117,8 @@ public async Task ExecuteWithUnsubscribeViaSubscriber() ...@@ -77,7 +117,8 @@ public async Task ExecuteWithUnsubscribeViaSubscriber()
} }
var subs = muxer.GetServer(muxer.GetEndPoints().Single()).SubscriptionSubscriberCount(name); var subs = muxer.GetServer(muxer.GetEndPoints().Single()).SubscriptionSubscriberCount(name);
Assert.Equal(1, subs); Assert.Equal(1, subs);
Assert.False(channel.Completion.IsCompleted, "completed"); Assert.False(first.Completion.IsCompleted, "completed");
Assert.False(second.Completion.IsCompleted, "completed");
await pubsub.UnsubscribeAsync(name); await pubsub.UnsubscribeAsync(name);
await Task.Delay(100); await Task.Delay(100);
...@@ -87,10 +128,13 @@ public async Task ExecuteWithUnsubscribeViaSubscriber() ...@@ -87,10 +128,13 @@ public async Task ExecuteWithUnsubscribeViaSubscriber()
{ {
Assert.Equal("abc", Assert.Single(values)); Assert.Equal("abc", Assert.Single(values));
} }
Assert.Equal(1, Volatile.Read(ref i));
subs = muxer.GetServer(muxer.GetEndPoints().Single()).SubscriptionSubscriberCount(name); subs = muxer.GetServer(muxer.GetEndPoints().Single()).SubscriptionSubscriberCount(name);
Assert.Equal(0, subs); Assert.Equal(0, subs);
Assert.True(channel.Completion.IsCompleted, "completed"); Assert.True(first.Completion.IsCompleted, "completed");
Assert.True(second.Completion.IsCompleted, "completed");
AssertCounts(pubsub, name, false, 0, 0);
} }
} }
...@@ -101,15 +145,20 @@ public async Task ExecuteWithUnsubscribeViaClearAll() ...@@ -101,15 +145,20 @@ public async Task ExecuteWithUnsubscribeViaClearAll()
{ {
RedisChannel name = Me(); RedisChannel name = Me();
var pubsub = muxer.GetSubscriber(); var pubsub = muxer.GetSubscriber();
AssertCounts(pubsub, name, false, 0, 0);
// subscribe and check we get data // subscribe and check we get data
var channel = await pubsub.SubscribeAsync(name); var first = await pubsub.SubscribeAsync(name);
var second = await pubsub.SubscribeAsync(name);
AssertCounts(pubsub, name, true, 0, 2);
List<string> values = new List<string>(); List<string> values = new List<string>();
channel.OnMessage(x => int i = 0;
first.OnMessage(x =>
{ {
lock (values) { values.Add(x.Message); } lock (values) { values.Add(x.Message); }
return Task.CompletedTask; return Task.CompletedTask;
}); });
second.OnMessage(_ => Interlocked.Increment(ref i));
await Task.Delay(100); await Task.Delay(100);
await pubsub.PublishAsync(name, "abc"); await pubsub.PublishAsync(name, "abc");
await Task.Delay(100); await Task.Delay(100);
...@@ -119,7 +168,8 @@ public async Task ExecuteWithUnsubscribeViaClearAll() ...@@ -119,7 +168,8 @@ public async Task ExecuteWithUnsubscribeViaClearAll()
} }
var subs = muxer.GetServer(muxer.GetEndPoints().Single()).SubscriptionSubscriberCount(name); var subs = muxer.GetServer(muxer.GetEndPoints().Single()).SubscriptionSubscriberCount(name);
Assert.Equal(1, subs); Assert.Equal(1, subs);
Assert.False(channel.Completion.IsCompleted, "completed"); Assert.False(first.Completion.IsCompleted, "completed");
Assert.False(second.Completion.IsCompleted, "completed");
await pubsub.UnsubscribeAllAsync(); await pubsub.UnsubscribeAllAsync();
await Task.Delay(100); await Task.Delay(100);
...@@ -129,10 +179,13 @@ public async Task ExecuteWithUnsubscribeViaClearAll() ...@@ -129,10 +179,13 @@ public async Task ExecuteWithUnsubscribeViaClearAll()
{ {
Assert.Equal("abc", Assert.Single(values)); Assert.Equal("abc", Assert.Single(values));
} }
Assert.Equal(1, Volatile.Read(ref i));
subs = muxer.GetServer(muxer.GetEndPoints().Single()).SubscriptionSubscriberCount(name); subs = muxer.GetServer(muxer.GetEndPoints().Single()).SubscriptionSubscriberCount(name);
Assert.Equal(0, subs); Assert.Equal(0, subs);
Assert.True(channel.Completion.IsCompleted, "completed"); Assert.True(first.Completion.IsCompleted, "completed");
Assert.True(second.Completion.IsCompleted, "completed");
AssertCounts(pubsub, name, false, 0, 0);
} }
} }
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment