Unverified Commit 93ee0fb6 authored by Marc Gravell's avatar Marc Gravell Committed by GitHub

Rework queues (#1112)

* rework how pubsub handlers/queues are stores and activated
- switch to linked-list for queues
- use alloc-free path for all handlers invokes

* defer writing to the queues until we've released the subscriber lock
parent 57c75d18
......@@ -86,11 +86,9 @@ internal ChannelMessageQueue(in RedisChannel redisChannel, RedisSubscriber paren
SingleReader = false,
AllowSynchronousContinuations = false,
};
internal void Subscribe(CommandFlags flags) => _parent.Subscribe(Channel, HandleMessage, flags);
internal Task SubscribeAsync(CommandFlags flags) => _parent.SubscribeAsync(Channel, HandleMessage, flags);
#pragma warning disable RCS1231 // Make parameter ref read-only. - uses as a delegate for Action<RedisChannel, RedisValue>
private void HandleMessage(RedisChannel channel, RedisValue value)
private void Write(in RedisChannel channel, in RedisValue value)
#pragma warning restore RCS1231 // Make parameter ref read-only.
{
var writer = _queue.Writer;
......@@ -170,6 +168,20 @@ private async Task OnMessageSyncImpl()
}
}
internal static void Combine(ref ChannelMessageQueue head, ChannelMessageQueue queue)
{
if (queue != null)
{
// insert at the start of the linked-list
ChannelMessageQueue old;
do
{
old = Volatile.Read(ref head);
queue._next = old;
} while (Interlocked.CompareExchange(ref head, queue, old) != old);
}
}
/// <summary>
/// Create a message loop that processes messages sequentially.
/// </summary>
......@@ -182,6 +194,75 @@ public void OnMessage(Func<ChannelMessage, Task> handler)
state => ((ChannelMessageQueue)state).OnMessageAsyncImpl().RedisFireAndForget(), this);
}
internal static void Remove(ref ChannelMessageQueue head, ChannelMessageQueue queue)
{
if (queue == null) return;
bool found;
do // if we fail due to a conflict, re-do from start
{
var current = Volatile.Read(ref head);
if (current == null) return; // no queue? nothing to do
if (current == queue)
{
found = true;
// found at the head - then we need to change the head
if (Interlocked.CompareExchange(ref head, Volatile.Read(ref current._next), current) == current)
{
return; // success
}
}
else
{
ChannelMessageQueue previous = current;
current = Volatile.Read(ref previous._next);
found = false;
do
{
if (current == queue)
{
found = true;
// found it, not at the head; remove the node
if (Interlocked.CompareExchange(ref previous._next, Volatile.Read(ref current._next), current) == current)
{
return; // success
}
else
{
break; // exit the inner loop, and repeat the outer loop
}
}
previous = current;
current = Volatile.Read(ref previous._next);
} while (current != null);
}
} while (found);
}
internal static int Count(ref ChannelMessageQueue head)
{
var current = Volatile.Read(ref head);
int count = 0;
while (current != null)
{
count++;
current = Volatile.Read(ref current._next);
}
return count;
}
internal static void WriteAll(ref ChannelMessageQueue head, in RedisChannel channel, in RedisValue message)
{
var current = Volatile.Read(ref head);
while (current != null)
{
current.Write(channel, message);
current = Volatile.Read(ref current._next);
}
}
private ChannelMessageQueue _next;
private async Task OnMessageAsyncImpl()
{
var handler = (Func<ChannelMessage, Task>)_onMessageHandler;
......@@ -205,14 +286,13 @@ private async Task OnMessageAsyncImpl()
}
}
internal static void MarkCompleted(Action<RedisChannel, RedisValue> handler)
internal static void MarkAllCompleted(ref ChannelMessageQueue head)
{
if (handler != null)
var current = Interlocked.Exchange(ref head, null);
while (current != null)
{
foreach (Action<RedisChannel, RedisValue> sub in handler.GetInvocationList())
{
if (sub.Target is ChannelMessageQueue queue) queue.MarkCompleted();
}
current.MarkCompleted();
current = Volatile.Read(ref current._next);
}
}
......@@ -228,7 +308,7 @@ internal void UnsubscribeImpl(Exception error = null, CommandFlags flags = Comma
_parent = null;
if (parent != null)
{
parent.UnsubscribeAsync(Channel, HandleMessage, flags);
parent.UnsubscribeAsync(Channel, null, this, flags);
}
_queue.Writer.TryComplete(error);
}
......@@ -239,24 +319,11 @@ internal async Task UnsubscribeAsyncImpl(Exception error = null, CommandFlags fl
_parent = null;
if (parent != null)
{
await parent.UnsubscribeAsync(Channel, HandleMessage, flags).ForAwait();
await parent.UnsubscribeAsync(Channel, null, this, flags).ForAwait();
}
_queue.Writer.TryComplete(error);
}
internal static bool IsOneOf(Action<RedisChannel, RedisValue> handler)
{
try
{
return handler?.Target is ChannelMessageQueue
&& handler.Method.Name == nameof(HandleMessage);
}
catch
{
return false;
}
}
/// <summary>
/// Stop receiving messages on this channel.
/// </summary>
......
using System;
using System.Text;
using Pipelines.Sockets.Unofficial;
namespace StackExchange.Redis
{
......@@ -7,15 +8,14 @@ internal sealed class MessageCompletable : ICompletable
{
private readonly RedisChannel channel;
private readonly Action<RedisChannel, RedisValue> syncHandler, asyncHandler;
private readonly Action<RedisChannel, RedisValue> handler;
private readonly RedisValue message;
public MessageCompletable(RedisChannel channel, RedisValue message, Action<RedisChannel, RedisValue> syncHandler, Action<RedisChannel, RedisValue> asyncHandler)
public MessageCompletable(RedisChannel channel, RedisValue message, Action<RedisChannel, RedisValue> handler)
{
this.channel = channel;
this.message = message;
this.syncHandler = syncHandler;
this.asyncHandler = asyncHandler;
this.handler = handler;
}
public override string ToString() => (string)channel;
......@@ -24,13 +24,19 @@ public bool TryComplete(bool isAsync)
{
if (isAsync)
{
if (asyncHandler != null)
if (handler != null)
{
ConnectionMultiplexer.TraceWithoutContext("Invoking (async)...: " + (string)channel, "Subscription");
foreach (Action<RedisChannel, RedisValue> sub in asyncHandler.GetInvocationList())
if (handler.IsSingle())
{
try { sub.Invoke(channel, message); }
catch { }
try { handler(channel, message); } catch { }
}
else
{
foreach (var sub in handler.AsEnumerable())
{
try { sub.Invoke(channel, message); } catch { }
}
}
ConnectionMultiplexer.TraceWithoutContext("Invoke complete (async)", "Subscription");
}
......@@ -38,17 +44,7 @@ public bool TryComplete(bool isAsync)
}
else
{
if (syncHandler != null)
{
ConnectionMultiplexer.TraceWithoutContext("Invoking (sync)...: " + (string)channel, "Subscription");
foreach (Action<RedisChannel, RedisValue> sub in syncHandler.GetInvocationList())
{
try { sub.Invoke(channel, message); }
catch { }
}
ConnectionMultiplexer.TraceWithoutContext("Invoke complete (sync)", "Subscription");
}
return asyncHandler == null; // anything async to do?
return handler == null; // anything async to do?
}
}
......
This diff is collapsed.
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Xunit;
using Xunit.Abstractions;
......@@ -10,6 +11,14 @@ public class Issue1101 : TestBase
{
public Issue1101(ITestOutputHelper output) : base(output) { }
static void AssertCounts(ISubscriber pubsub, in RedisChannel channel,
bool has, int handlers, int queues)
{
var aHas = ((RedisSubscriber)pubsub).GetSubscriberCounts(channel, out var ah, out var aq);
Assert.Equal(has, aHas);
Assert.Equal(handlers, ah);
Assert.Equal(queues, aq);
}
[Fact]
public async Task ExecuteWithUnsubscribeViaChannel()
{
......@@ -17,15 +26,20 @@ public async Task ExecuteWithUnsubscribeViaChannel()
{
RedisChannel name = Me();
var pubsub = muxer.GetSubscriber();
AssertCounts(pubsub, name, false, 0, 0);
// subscribe and check we get data
var channel = await pubsub.SubscribeAsync(name);
var first = await pubsub.SubscribeAsync(name);
var second = await pubsub.SubscribeAsync(name);
AssertCounts(pubsub, name, true, 0, 2);
List<string> values = new List<string>();
channel.OnMessage(x =>
int i = 0;
first.OnMessage(x =>
{
lock (values) { values.Add(x.Message); }
return Task.CompletedTask;
});
second.OnMessage(_ => Interlocked.Increment(ref i));
await Task.Delay(100);
await pubsub.PublishAsync(name, "abc");
await Task.Delay(100);
......@@ -35,9 +49,10 @@ public async Task ExecuteWithUnsubscribeViaChannel()
}
var subs = muxer.GetServer(muxer.GetEndPoints().Single()).SubscriptionSubscriberCount(name);
Assert.Equal(1, subs);
Assert.False(channel.Completion.IsCompleted, "completed");
Assert.False(first.Completion.IsCompleted, "completed");
Assert.False(second.Completion.IsCompleted, "completed");
await channel.UnsubscribeAsync();
await first.UnsubscribeAsync();
await Task.Delay(100);
await pubsub.PublishAsync(name, "def");
await Task.Delay(100);
......@@ -45,10 +60,29 @@ public async Task ExecuteWithUnsubscribeViaChannel()
{
Assert.Equal("abc", Assert.Single(values));
}
Assert.Equal(2, Volatile.Read(ref i));
Assert.True(first.Completion.IsCompleted, "completed");
Assert.False(second.Completion.IsCompleted, "completed");
AssertCounts(pubsub, name, true, 0, 1);
await second.UnsubscribeAsync();
await Task.Delay(100);
await pubsub.PublishAsync(name, "ghi");
await Task.Delay(100);
lock (values)
{
Assert.Equal("abc", Assert.Single(values));
}
Assert.Equal(2, Volatile.Read(ref i));
Assert.True(first.Completion.IsCompleted, "completed");
Assert.True(second.Completion.IsCompleted, "completed");
AssertCounts(pubsub, name, false, 0, 0);
subs = muxer.GetServer(muxer.GetEndPoints().Single()).SubscriptionSubscriberCount(name);
Assert.Equal(0, subs);
Assert.True(channel.Completion.IsCompleted, "completed");
Assert.True(first.Completion.IsCompleted, "completed");
Assert.True(second.Completion.IsCompleted, "completed");
}
}
......@@ -59,15 +93,21 @@ public async Task ExecuteWithUnsubscribeViaSubscriber()
{
RedisChannel name = Me();
var pubsub = muxer.GetSubscriber();
AssertCounts(pubsub, name, false, 0, 0);
// subscribe and check we get data
var channel = await pubsub.SubscribeAsync(name);
var first = await pubsub.SubscribeAsync(name);
var second = await pubsub.SubscribeAsync(name);
AssertCounts(pubsub, name, true, 0, 2);
List<string> values = new List<string>();
channel.OnMessage(x =>
int i = 0;
first.OnMessage(x =>
{
lock (values) { values.Add(x.Message); }
return Task.CompletedTask;
});
second.OnMessage(_ => Interlocked.Increment(ref i));
await Task.Delay(100);
await pubsub.PublishAsync(name, "abc");
await Task.Delay(100);
......@@ -77,7 +117,8 @@ public async Task ExecuteWithUnsubscribeViaSubscriber()
}
var subs = muxer.GetServer(muxer.GetEndPoints().Single()).SubscriptionSubscriberCount(name);
Assert.Equal(1, subs);
Assert.False(channel.Completion.IsCompleted, "completed");
Assert.False(first.Completion.IsCompleted, "completed");
Assert.False(second.Completion.IsCompleted, "completed");
await pubsub.UnsubscribeAsync(name);
await Task.Delay(100);
......@@ -87,10 +128,13 @@ public async Task ExecuteWithUnsubscribeViaSubscriber()
{
Assert.Equal("abc", Assert.Single(values));
}
Assert.Equal(1, Volatile.Read(ref i));
subs = muxer.GetServer(muxer.GetEndPoints().Single()).SubscriptionSubscriberCount(name);
Assert.Equal(0, subs);
Assert.True(channel.Completion.IsCompleted, "completed");
Assert.True(first.Completion.IsCompleted, "completed");
Assert.True(second.Completion.IsCompleted, "completed");
AssertCounts(pubsub, name, false, 0, 0);
}
}
......@@ -101,15 +145,20 @@ public async Task ExecuteWithUnsubscribeViaClearAll()
{
RedisChannel name = Me();
var pubsub = muxer.GetSubscriber();
AssertCounts(pubsub, name, false, 0, 0);
// subscribe and check we get data
var channel = await pubsub.SubscribeAsync(name);
var first = await pubsub.SubscribeAsync(name);
var second = await pubsub.SubscribeAsync(name);
AssertCounts(pubsub, name, true, 0, 2);
List<string> values = new List<string>();
channel.OnMessage(x =>
int i = 0;
first.OnMessage(x =>
{
lock (values) { values.Add(x.Message); }
return Task.CompletedTask;
});
second.OnMessage(_ => Interlocked.Increment(ref i));
await Task.Delay(100);
await pubsub.PublishAsync(name, "abc");
await Task.Delay(100);
......@@ -119,7 +168,8 @@ public async Task ExecuteWithUnsubscribeViaClearAll()
}
var subs = muxer.GetServer(muxer.GetEndPoints().Single()).SubscriptionSubscriberCount(name);
Assert.Equal(1, subs);
Assert.False(channel.Completion.IsCompleted, "completed");
Assert.False(first.Completion.IsCompleted, "completed");
Assert.False(second.Completion.IsCompleted, "completed");
await pubsub.UnsubscribeAllAsync();
await Task.Delay(100);
......@@ -129,10 +179,13 @@ public async Task ExecuteWithUnsubscribeViaClearAll()
{
Assert.Equal("abc", Assert.Single(values));
}
Assert.Equal(1, Volatile.Read(ref i));
subs = muxer.GetServer(muxer.GetEndPoints().Single()).SubscriptionSubscriberCount(name);
Assert.Equal(0, subs);
Assert.True(channel.Completion.IsCompleted, "completed");
Assert.True(first.Completion.IsCompleted, "completed");
Assert.True(second.Completion.IsCompleted, "completed");
AssertCounts(pubsub, name, false, 0, 0);
}
}
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment