Unverified Commit 93ee0fb6 authored by Marc Gravell's avatar Marc Gravell Committed by GitHub

Rework queues (#1112)

* rework how pubsub handlers/queues are stores and activated
- switch to linked-list for queues
- use alloc-free path for all handlers invokes

* defer writing to the queues until we've released the subscriber lock
parent 57c75d18
...@@ -86,11 +86,9 @@ internal ChannelMessageQueue(in RedisChannel redisChannel, RedisSubscriber paren ...@@ -86,11 +86,9 @@ internal ChannelMessageQueue(in RedisChannel redisChannel, RedisSubscriber paren
SingleReader = false, SingleReader = false,
AllowSynchronousContinuations = false, AllowSynchronousContinuations = false,
}; };
internal void Subscribe(CommandFlags flags) => _parent.Subscribe(Channel, HandleMessage, flags);
internal Task SubscribeAsync(CommandFlags flags) => _parent.SubscribeAsync(Channel, HandleMessage, flags);
#pragma warning disable RCS1231 // Make parameter ref read-only. - uses as a delegate for Action<RedisChannel, RedisValue> #pragma warning disable RCS1231 // Make parameter ref read-only. - uses as a delegate for Action<RedisChannel, RedisValue>
private void HandleMessage(RedisChannel channel, RedisValue value) private void Write(in RedisChannel channel, in RedisValue value)
#pragma warning restore RCS1231 // Make parameter ref read-only. #pragma warning restore RCS1231 // Make parameter ref read-only.
{ {
var writer = _queue.Writer; var writer = _queue.Writer;
...@@ -170,6 +168,20 @@ private async Task OnMessageSyncImpl() ...@@ -170,6 +168,20 @@ private async Task OnMessageSyncImpl()
} }
} }
internal static void Combine(ref ChannelMessageQueue head, ChannelMessageQueue queue)
{
if (queue != null)
{
// insert at the start of the linked-list
ChannelMessageQueue old;
do
{
old = Volatile.Read(ref head);
queue._next = old;
} while (Interlocked.CompareExchange(ref head, queue, old) != old);
}
}
/// <summary> /// <summary>
/// Create a message loop that processes messages sequentially. /// Create a message loop that processes messages sequentially.
/// </summary> /// </summary>
...@@ -182,6 +194,75 @@ public void OnMessage(Func<ChannelMessage, Task> handler) ...@@ -182,6 +194,75 @@ public void OnMessage(Func<ChannelMessage, Task> handler)
state => ((ChannelMessageQueue)state).OnMessageAsyncImpl().RedisFireAndForget(), this); state => ((ChannelMessageQueue)state).OnMessageAsyncImpl().RedisFireAndForget(), this);
} }
internal static void Remove(ref ChannelMessageQueue head, ChannelMessageQueue queue)
{
if (queue == null) return;
bool found;
do // if we fail due to a conflict, re-do from start
{
var current = Volatile.Read(ref head);
if (current == null) return; // no queue? nothing to do
if (current == queue)
{
found = true;
// found at the head - then we need to change the head
if (Interlocked.CompareExchange(ref head, Volatile.Read(ref current._next), current) == current)
{
return; // success
}
}
else
{
ChannelMessageQueue previous = current;
current = Volatile.Read(ref previous._next);
found = false;
do
{
if (current == queue)
{
found = true;
// found it, not at the head; remove the node
if (Interlocked.CompareExchange(ref previous._next, Volatile.Read(ref current._next), current) == current)
{
return; // success
}
else
{
break; // exit the inner loop, and repeat the outer loop
}
}
previous = current;
current = Volatile.Read(ref previous._next);
} while (current != null);
}
} while (found);
}
internal static int Count(ref ChannelMessageQueue head)
{
var current = Volatile.Read(ref head);
int count = 0;
while (current != null)
{
count++;
current = Volatile.Read(ref current._next);
}
return count;
}
internal static void WriteAll(ref ChannelMessageQueue head, in RedisChannel channel, in RedisValue message)
{
var current = Volatile.Read(ref head);
while (current != null)
{
current.Write(channel, message);
current = Volatile.Read(ref current._next);
}
}
private ChannelMessageQueue _next;
private async Task OnMessageAsyncImpl() private async Task OnMessageAsyncImpl()
{ {
var handler = (Func<ChannelMessage, Task>)_onMessageHandler; var handler = (Func<ChannelMessage, Task>)_onMessageHandler;
...@@ -205,14 +286,13 @@ private async Task OnMessageAsyncImpl() ...@@ -205,14 +286,13 @@ private async Task OnMessageAsyncImpl()
} }
} }
internal static void MarkCompleted(Action<RedisChannel, RedisValue> handler) internal static void MarkAllCompleted(ref ChannelMessageQueue head)
{ {
if (handler != null) var current = Interlocked.Exchange(ref head, null);
while (current != null)
{ {
foreach (Action<RedisChannel, RedisValue> sub in handler.GetInvocationList()) current.MarkCompleted();
{ current = Volatile.Read(ref current._next);
if (sub.Target is ChannelMessageQueue queue) queue.MarkCompleted();
}
} }
} }
...@@ -228,7 +308,7 @@ internal void UnsubscribeImpl(Exception error = null, CommandFlags flags = Comma ...@@ -228,7 +308,7 @@ internal void UnsubscribeImpl(Exception error = null, CommandFlags flags = Comma
_parent = null; _parent = null;
if (parent != null) if (parent != null)
{ {
parent.UnsubscribeAsync(Channel, HandleMessage, flags); parent.UnsubscribeAsync(Channel, null, this, flags);
} }
_queue.Writer.TryComplete(error); _queue.Writer.TryComplete(error);
} }
...@@ -239,24 +319,11 @@ internal async Task UnsubscribeAsyncImpl(Exception error = null, CommandFlags fl ...@@ -239,24 +319,11 @@ internal async Task UnsubscribeAsyncImpl(Exception error = null, CommandFlags fl
_parent = null; _parent = null;
if (parent != null) if (parent != null)
{ {
await parent.UnsubscribeAsync(Channel, HandleMessage, flags).ForAwait(); await parent.UnsubscribeAsync(Channel, null, this, flags).ForAwait();
} }
_queue.Writer.TryComplete(error); _queue.Writer.TryComplete(error);
} }
internal static bool IsOneOf(Action<RedisChannel, RedisValue> handler)
{
try
{
return handler?.Target is ChannelMessageQueue
&& handler.Method.Name == nameof(HandleMessage);
}
catch
{
return false;
}
}
/// <summary> /// <summary>
/// Stop receiving messages on this channel. /// Stop receiving messages on this channel.
/// </summary> /// </summary>
......
using System; using System;
using System.Text; using System.Text;
using Pipelines.Sockets.Unofficial;
namespace StackExchange.Redis namespace StackExchange.Redis
{ {
...@@ -7,15 +8,14 @@ internal sealed class MessageCompletable : ICompletable ...@@ -7,15 +8,14 @@ internal sealed class MessageCompletable : ICompletable
{ {
private readonly RedisChannel channel; private readonly RedisChannel channel;
private readonly Action<RedisChannel, RedisValue> syncHandler, asyncHandler; private readonly Action<RedisChannel, RedisValue> handler;
private readonly RedisValue message; private readonly RedisValue message;
public MessageCompletable(RedisChannel channel, RedisValue message, Action<RedisChannel, RedisValue> syncHandler, Action<RedisChannel, RedisValue> asyncHandler) public MessageCompletable(RedisChannel channel, RedisValue message, Action<RedisChannel, RedisValue> handler)
{ {
this.channel = channel; this.channel = channel;
this.message = message; this.message = message;
this.syncHandler = syncHandler; this.handler = handler;
this.asyncHandler = asyncHandler;
} }
public override string ToString() => (string)channel; public override string ToString() => (string)channel;
...@@ -24,13 +24,19 @@ public bool TryComplete(bool isAsync) ...@@ -24,13 +24,19 @@ public bool TryComplete(bool isAsync)
{ {
if (isAsync) if (isAsync)
{ {
if (asyncHandler != null) if (handler != null)
{ {
ConnectionMultiplexer.TraceWithoutContext("Invoking (async)...: " + (string)channel, "Subscription"); ConnectionMultiplexer.TraceWithoutContext("Invoking (async)...: " + (string)channel, "Subscription");
foreach (Action<RedisChannel, RedisValue> sub in asyncHandler.GetInvocationList()) if (handler.IsSingle())
{ {
try { sub.Invoke(channel, message); } try { handler(channel, message); } catch { }
catch { } }
else
{
foreach (var sub in handler.AsEnumerable())
{
try { sub.Invoke(channel, message); } catch { }
}
} }
ConnectionMultiplexer.TraceWithoutContext("Invoke complete (async)", "Subscription"); ConnectionMultiplexer.TraceWithoutContext("Invoke complete (async)", "Subscription");
} }
...@@ -38,17 +44,7 @@ public bool TryComplete(bool isAsync) ...@@ -38,17 +44,7 @@ public bool TryComplete(bool isAsync)
} }
else else
{ {
if (syncHandler != null) return handler == null; // anything async to do?
{
ConnectionMultiplexer.TraceWithoutContext("Invoking (sync)...: " + (string)channel, "Subscription");
foreach (Action<RedisChannel, RedisValue> sub in syncHandler.GetInvocationList())
{
try { sub.Invoke(channel, message); }
catch { }
}
ConnectionMultiplexer.TraceWithoutContext("Invoke complete (sync)", "Subscription");
}
return asyncHandler == null; // anything async to do?
} }
} }
......
This diff is collapsed.
using System.Collections.Generic; using System.Collections.Generic;
using System.Linq; using System.Linq;
using System.Threading;
using System.Threading.Tasks; using System.Threading.Tasks;
using Xunit; using Xunit;
using Xunit.Abstractions; using Xunit.Abstractions;
...@@ -10,6 +11,14 @@ public class Issue1101 : TestBase ...@@ -10,6 +11,14 @@ public class Issue1101 : TestBase
{ {
public Issue1101(ITestOutputHelper output) : base(output) { } public Issue1101(ITestOutputHelper output) : base(output) { }
static void AssertCounts(ISubscriber pubsub, in RedisChannel channel,
bool has, int handlers, int queues)
{
var aHas = ((RedisSubscriber)pubsub).GetSubscriberCounts(channel, out var ah, out var aq);
Assert.Equal(has, aHas);
Assert.Equal(handlers, ah);
Assert.Equal(queues, aq);
}
[Fact] [Fact]
public async Task ExecuteWithUnsubscribeViaChannel() public async Task ExecuteWithUnsubscribeViaChannel()
{ {
...@@ -17,15 +26,20 @@ public async Task ExecuteWithUnsubscribeViaChannel() ...@@ -17,15 +26,20 @@ public async Task ExecuteWithUnsubscribeViaChannel()
{ {
RedisChannel name = Me(); RedisChannel name = Me();
var pubsub = muxer.GetSubscriber(); var pubsub = muxer.GetSubscriber();
AssertCounts(pubsub, name, false, 0, 0);
// subscribe and check we get data // subscribe and check we get data
var channel = await pubsub.SubscribeAsync(name); var first = await pubsub.SubscribeAsync(name);
var second = await pubsub.SubscribeAsync(name);
AssertCounts(pubsub, name, true, 0, 2);
List<string> values = new List<string>(); List<string> values = new List<string>();
channel.OnMessage(x => int i = 0;
first.OnMessage(x =>
{ {
lock (values) { values.Add(x.Message); } lock (values) { values.Add(x.Message); }
return Task.CompletedTask; return Task.CompletedTask;
}); });
second.OnMessage(_ => Interlocked.Increment(ref i));
await Task.Delay(100); await Task.Delay(100);
await pubsub.PublishAsync(name, "abc"); await pubsub.PublishAsync(name, "abc");
await Task.Delay(100); await Task.Delay(100);
...@@ -35,9 +49,10 @@ public async Task ExecuteWithUnsubscribeViaChannel() ...@@ -35,9 +49,10 @@ public async Task ExecuteWithUnsubscribeViaChannel()
} }
var subs = muxer.GetServer(muxer.GetEndPoints().Single()).SubscriptionSubscriberCount(name); var subs = muxer.GetServer(muxer.GetEndPoints().Single()).SubscriptionSubscriberCount(name);
Assert.Equal(1, subs); Assert.Equal(1, subs);
Assert.False(channel.Completion.IsCompleted, "completed"); Assert.False(first.Completion.IsCompleted, "completed");
Assert.False(second.Completion.IsCompleted, "completed");
await channel.UnsubscribeAsync(); await first.UnsubscribeAsync();
await Task.Delay(100); await Task.Delay(100);
await pubsub.PublishAsync(name, "def"); await pubsub.PublishAsync(name, "def");
await Task.Delay(100); await Task.Delay(100);
...@@ -45,10 +60,29 @@ public async Task ExecuteWithUnsubscribeViaChannel() ...@@ -45,10 +60,29 @@ public async Task ExecuteWithUnsubscribeViaChannel()
{ {
Assert.Equal("abc", Assert.Single(values)); Assert.Equal("abc", Assert.Single(values));
} }
Assert.Equal(2, Volatile.Read(ref i));
Assert.True(first.Completion.IsCompleted, "completed");
Assert.False(second.Completion.IsCompleted, "completed");
AssertCounts(pubsub, name, true, 0, 1);
await second.UnsubscribeAsync();
await Task.Delay(100);
await pubsub.PublishAsync(name, "ghi");
await Task.Delay(100);
lock (values)
{
Assert.Equal("abc", Assert.Single(values));
}
Assert.Equal(2, Volatile.Read(ref i));
Assert.True(first.Completion.IsCompleted, "completed");
Assert.True(second.Completion.IsCompleted, "completed");
AssertCounts(pubsub, name, false, 0, 0);
subs = muxer.GetServer(muxer.GetEndPoints().Single()).SubscriptionSubscriberCount(name); subs = muxer.GetServer(muxer.GetEndPoints().Single()).SubscriptionSubscriberCount(name);
Assert.Equal(0, subs); Assert.Equal(0, subs);
Assert.True(channel.Completion.IsCompleted, "completed"); Assert.True(first.Completion.IsCompleted, "completed");
Assert.True(second.Completion.IsCompleted, "completed");
} }
} }
...@@ -59,15 +93,21 @@ public async Task ExecuteWithUnsubscribeViaSubscriber() ...@@ -59,15 +93,21 @@ public async Task ExecuteWithUnsubscribeViaSubscriber()
{ {
RedisChannel name = Me(); RedisChannel name = Me();
var pubsub = muxer.GetSubscriber(); var pubsub = muxer.GetSubscriber();
AssertCounts(pubsub, name, false, 0, 0);
// subscribe and check we get data // subscribe and check we get data
var channel = await pubsub.SubscribeAsync(name); var first = await pubsub.SubscribeAsync(name);
var second = await pubsub.SubscribeAsync(name);
AssertCounts(pubsub, name, true, 0, 2);
List<string> values = new List<string>(); List<string> values = new List<string>();
channel.OnMessage(x => int i = 0;
first.OnMessage(x =>
{ {
lock (values) { values.Add(x.Message); } lock (values) { values.Add(x.Message); }
return Task.CompletedTask; return Task.CompletedTask;
}); });
second.OnMessage(_ => Interlocked.Increment(ref i));
await Task.Delay(100); await Task.Delay(100);
await pubsub.PublishAsync(name, "abc"); await pubsub.PublishAsync(name, "abc");
await Task.Delay(100); await Task.Delay(100);
...@@ -77,7 +117,8 @@ public async Task ExecuteWithUnsubscribeViaSubscriber() ...@@ -77,7 +117,8 @@ public async Task ExecuteWithUnsubscribeViaSubscriber()
} }
var subs = muxer.GetServer(muxer.GetEndPoints().Single()).SubscriptionSubscriberCount(name); var subs = muxer.GetServer(muxer.GetEndPoints().Single()).SubscriptionSubscriberCount(name);
Assert.Equal(1, subs); Assert.Equal(1, subs);
Assert.False(channel.Completion.IsCompleted, "completed"); Assert.False(first.Completion.IsCompleted, "completed");
Assert.False(second.Completion.IsCompleted, "completed");
await pubsub.UnsubscribeAsync(name); await pubsub.UnsubscribeAsync(name);
await Task.Delay(100); await Task.Delay(100);
...@@ -87,10 +128,13 @@ public async Task ExecuteWithUnsubscribeViaSubscriber() ...@@ -87,10 +128,13 @@ public async Task ExecuteWithUnsubscribeViaSubscriber()
{ {
Assert.Equal("abc", Assert.Single(values)); Assert.Equal("abc", Assert.Single(values));
} }
Assert.Equal(1, Volatile.Read(ref i));
subs = muxer.GetServer(muxer.GetEndPoints().Single()).SubscriptionSubscriberCount(name); subs = muxer.GetServer(muxer.GetEndPoints().Single()).SubscriptionSubscriberCount(name);
Assert.Equal(0, subs); Assert.Equal(0, subs);
Assert.True(channel.Completion.IsCompleted, "completed"); Assert.True(first.Completion.IsCompleted, "completed");
Assert.True(second.Completion.IsCompleted, "completed");
AssertCounts(pubsub, name, false, 0, 0);
} }
} }
...@@ -101,15 +145,20 @@ public async Task ExecuteWithUnsubscribeViaClearAll() ...@@ -101,15 +145,20 @@ public async Task ExecuteWithUnsubscribeViaClearAll()
{ {
RedisChannel name = Me(); RedisChannel name = Me();
var pubsub = muxer.GetSubscriber(); var pubsub = muxer.GetSubscriber();
AssertCounts(pubsub, name, false, 0, 0);
// subscribe and check we get data // subscribe and check we get data
var channel = await pubsub.SubscribeAsync(name); var first = await pubsub.SubscribeAsync(name);
var second = await pubsub.SubscribeAsync(name);
AssertCounts(pubsub, name, true, 0, 2);
List<string> values = new List<string>(); List<string> values = new List<string>();
channel.OnMessage(x => int i = 0;
first.OnMessage(x =>
{ {
lock (values) { values.Add(x.Message); } lock (values) { values.Add(x.Message); }
return Task.CompletedTask; return Task.CompletedTask;
}); });
second.OnMessage(_ => Interlocked.Increment(ref i));
await Task.Delay(100); await Task.Delay(100);
await pubsub.PublishAsync(name, "abc"); await pubsub.PublishAsync(name, "abc");
await Task.Delay(100); await Task.Delay(100);
...@@ -119,7 +168,8 @@ public async Task ExecuteWithUnsubscribeViaClearAll() ...@@ -119,7 +168,8 @@ public async Task ExecuteWithUnsubscribeViaClearAll()
} }
var subs = muxer.GetServer(muxer.GetEndPoints().Single()).SubscriptionSubscriberCount(name); var subs = muxer.GetServer(muxer.GetEndPoints().Single()).SubscriptionSubscriberCount(name);
Assert.Equal(1, subs); Assert.Equal(1, subs);
Assert.False(channel.Completion.IsCompleted, "completed"); Assert.False(first.Completion.IsCompleted, "completed");
Assert.False(second.Completion.IsCompleted, "completed");
await pubsub.UnsubscribeAllAsync(); await pubsub.UnsubscribeAllAsync();
await Task.Delay(100); await Task.Delay(100);
...@@ -129,10 +179,13 @@ public async Task ExecuteWithUnsubscribeViaClearAll() ...@@ -129,10 +179,13 @@ public async Task ExecuteWithUnsubscribeViaClearAll()
{ {
Assert.Equal("abc", Assert.Single(values)); Assert.Equal("abc", Assert.Single(values));
} }
Assert.Equal(1, Volatile.Read(ref i));
subs = muxer.GetServer(muxer.GetEndPoints().Single()).SubscriptionSubscriberCount(name); subs = muxer.GetServer(muxer.GetEndPoints().Single()).SubscriptionSubscriberCount(name);
Assert.Equal(0, subs); Assert.Equal(0, subs);
Assert.True(channel.Completion.IsCompleted, "completed"); Assert.True(first.Completion.IsCompleted, "completed");
Assert.True(second.Completion.IsCompleted, "completed");
AssertCounts(pubsub, name, false, 0, 0);
} }
} }
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment