Commit 3627d330 authored by Marc Gravell's avatar Marc Gravell

implement OnMessage for both sync and async cases; fix handling of all other subscriptions

parent 04a7bf90
...@@ -84,7 +84,7 @@ public async Task PubSubGetAllAnyOrder() ...@@ -84,7 +84,7 @@ public async Task PubSubGetAllAnyOrder()
{ {
var sub = muxer.GetSubscriber(); var sub = muxer.GetSubscriber();
RedisChannel channel = Me(); RedisChannel channel = Me();
const int count = 500000; const int count = 1000;
var syncLock = new object(); var syncLock = new object();
var data = new HashSet<int>(); var data = new HashSet<int>();
...@@ -125,7 +125,7 @@ public async Task PubSubGetAllAnyOrder() ...@@ -125,7 +125,7 @@ public async Task PubSubGetAllAnyOrder()
} }
} }
[FactLongRunning] [Fact]
public async Task PubSubGetAllCorrectOrder() public async Task PubSubGetAllCorrectOrder()
{ {
using (var muxer = GetRemoteConnection(waitForOpen: true, using (var muxer = GetRemoteConnection(waitForOpen: true,
...@@ -133,7 +133,7 @@ public async Task PubSubGetAllCorrectOrder() ...@@ -133,7 +133,7 @@ public async Task PubSubGetAllCorrectOrder()
{ {
var sub = muxer.GetSubscriber(); var sub = muxer.GetSubscriber();
RedisChannel channel = Me(); RedisChannel channel = Me();
const int count = 500000; const int count = 1000;
var syncLock = new object(); var syncLock = new object();
var data = new List<int>(count); var data = new List<int>(count);
...@@ -143,7 +143,7 @@ public async Task PubSubGetAllCorrectOrder() ...@@ -143,7 +143,7 @@ public async Task PubSubGetAllCorrectOrder()
async Task RunLoop() async Task RunLoop()
{ {
while (!subChannel.IsComplete) while (!subChannel.IsCompleted)
{ {
var work = await subChannel.ReadAsync(); var work = await subChannel.ReadAsync();
int i = int.Parse(Encoding.UTF8.GetString(work.Value)); int i = int.Parse(Encoding.UTF8.GetString(work.Value));
...@@ -180,14 +180,137 @@ async Task RunLoop() ...@@ -180,14 +180,137 @@ async Task RunLoop()
} }
} }
Assert.True(subChannel.IsComplete); Assert.True(subChannel.IsCompleted);
await Assert.ThrowsAsync<ChannelClosedException>(async delegate { await Assert.ThrowsAsync<ChannelClosedException>(async delegate
{
var final = await subChannel.ReadAsync();
});
}
}
[Fact]
public async Task PubSubGetAllCorrectOrder_OnMessage_Sync()
{
using (var muxer = GetRemoteConnection(waitForOpen: true,
syncTimeout: 20000))
{
var sub = muxer.GetSubscriber();
RedisChannel channel = Me();
const int count = 1000;
var syncLock = new object();
var data = new List<int>(count);
var subChannel = await sub.SubscribeAsync(channel);
subChannel.OnMessage((key, val) =>
{
int i = int.Parse(Encoding.UTF8.GetString(val));
bool pulse = false;
lock (data)
{
data.Add(i);
if (data.Count == count) pulse = true;
if ((data.Count % 10) == 99) Output.WriteLine(data.Count.ToString());
}
if (pulse)
{
lock (syncLock)
{
Monitor.PulseAll(syncLock);
}
}
});
await sub.PingAsync();
lock (syncLock)
{
for (int i = 0; i < count; i++)
{
sub.Publish(channel, i.ToString(), CommandFlags.FireAndForget);
}
if (!Monitor.Wait(syncLock, 20000))
{
throw new TimeoutException("Items: " + data.Count);
}
subChannel.Unsubscribe();
sub.Ping();
muxer.GetDatabase().Ping();
for (int i = 0; i < count; i++)
{
Assert.Equal(i, data[i]);
}
}
Assert.True(subChannel.IsCompleted);
await Assert.ThrowsAsync<ChannelClosedException>(async delegate
{
var final = await subChannel.ReadAsync(); var final = await subChannel.ReadAsync();
}); });
} }
} }
[Fact]
public async Task PubSubGetAllCorrectOrder_OnMessage_Async()
{
using (var muxer = GetRemoteConnection(waitForOpen: true,
syncTimeout: 20000))
{
var sub = muxer.GetSubscriber();
RedisChannel channel = Me();
const int count = 1000;
var syncLock = new object();
var data = new List<int>(count);
var subChannel = await sub.SubscribeAsync(channel);
subChannel.OnMessage((key, val) =>
{
int i = int.Parse(Encoding.UTF8.GetString(val));
bool pulse = false;
lock (data)
{
data.Add(i);
if (data.Count == count) pulse = true;
if ((data.Count % 10) == 99) Output.WriteLine(data.Count.ToString());
}
if (pulse)
{
lock (syncLock)
{
Monitor.PulseAll(syncLock);
}
}
return i % 2 == 0 ? null : Task.CompletedTask;
});
await sub.PingAsync();
lock (syncLock)
{
for (int i = 0; i < count; i++)
{
sub.Publish(channel, i.ToString(), CommandFlags.FireAndForget);
}
if (!Monitor.Wait(syncLock, 20000))
{
throw new TimeoutException("Items: " + data.Count);
}
subChannel.Unsubscribe();
sub.Ping();
muxer.GetDatabase().Ping();
for (int i = 0; i < count; i++)
{
Assert.Equal(i, data[i]);
}
}
Assert.True(subChannel.IsCompleted);
await Assert.ThrowsAsync<ChannelClosedException>(async delegate
{
var final = await subChannel.ReadAsync();
});
}
}
[Fact] [Fact]
public void TestPublishWithSubscribers() public void TestPublishWithSubscribers()
......
...@@ -28,27 +28,27 @@ internal ChannelMessage(RedisChannel channel, RedisValue value) ...@@ -28,27 +28,27 @@ internal ChannelMessage(RedisChannel channel, RedisValue value)
/// <summary> /// <summary>
/// Represents a message queue of pub/sub notifications /// Represents a message queue of ordered pub/sub notifications
/// </summary> /// </summary>
/// <remarks>To create a ChannelMessageQueue, use ISubscriber.Subscribe[Async](RedisKey)</remarks> /// <remarks>To create a ChannelMessageQueue, use ISubscriber.Subscribe[Async](RedisKey)</remarks>
public sealed class ChannelMessageQueue public sealed class ChannelMessageQueue
{ {
private readonly Channel<ChannelMessage> _channel; private readonly Channel<ChannelMessage> _channel;
private readonly RedisChannel _redisChannel; private readonly RedisChannel _redisChannel;
private ISubscriber _parent; private RedisSubscriber _parent;
/// <summary> /// <summary>
/// Indicates if all messages that will be received have been drained from this channel /// Indicates if all messages that will be received have been drained from this channel
/// </summary> /// </summary>
public bool IsComplete { get; private set; } public bool IsCompleted { get; private set; }
internal ChannelMessageQueue(RedisChannel redisChannel, ISubscriber parent) internal ChannelMessageQueue(RedisChannel redisChannel, RedisSubscriber parent)
{ {
_redisChannel = redisChannel; _redisChannel = redisChannel;
_parent = parent; _parent = parent;
_channel = Channel.CreateUnbounded<ChannelMessage>(s_ChannelOptions); _channel = Channel.CreateUnbounded<ChannelMessage>(s_ChannelOptions);
_channel.Reader.Completion.ContinueWith( _channel.Reader.Completion.ContinueWith(
(t, state) => ((ChannelMessageQueue)state).IsComplete = true, this, TaskContinuationOptions.ExecuteSynchronously); (t, state) => ((ChannelMessageQueue)state).IsCompleted = true, this, TaskContinuationOptions.ExecuteSynchronously);
} }
static readonly UnboundedChannelOptions s_ChannelOptions = new UnboundedChannelOptions static readonly UnboundedChannelOptions s_ChannelOptions = new UnboundedChannelOptions
{ {
...@@ -79,6 +79,79 @@ private void HandleMessage(RedisChannel channel, RedisValue value) ...@@ -79,6 +79,79 @@ private void HandleMessage(RedisChannel channel, RedisValue value)
public ValueTask<ChannelMessage> ReadAsync(CancellationToken cancellationToken = default) public ValueTask<ChannelMessage> ReadAsync(CancellationToken cancellationToken = default)
=> _channel.Reader.ReadAsync(cancellationToken); => _channel.Reader.ReadAsync(cancellationToken);
/// <summary>
/// Attempt to synchronously consume a message from the channel
/// </summary>
public bool TryRead(out ChannelMessage item) => _channel.Reader.TryRead(out item);
private Delegate _onMessageHandler;
private void AssertOnMessage(Delegate handler)
{
if (handler == null) throw new ArgumentNullException(nameof(handler));
if (Interlocked.CompareExchange(ref _onMessageHandler, handler, null) != null)
throw new InvalidOperationException("Only a single " + nameof(OnMessage) + " is allowed");
}
/// <summary>
/// Create a message loop that processes messages sequentially
/// </summary>
public void OnMessage(Action<RedisChannel, RedisValue> handler)
{
AssertOnMessage(handler);
ThreadPool.QueueUserWorkItem(
state => ((ChannelMessageQueue)state).OnMessageSyncImpl(), this);
}
private async void OnMessageSyncImpl()
{
var handler = (Action<RedisChannel, RedisValue>)_onMessageHandler;
while (!IsCompleted)
{
ChannelMessage next;
try { if(!TryRead(out next)) next = await ReadAsync(); }
catch (ChannelClosedException) { break; } // expected
catch (Exception ex)
{
_parent.multiplexer?.OnInternalError(ex);
break;
}
try { handler.Invoke(next.Channel, next.Value); }
catch { } // matches MessageCompletable
}
}
/// <summary>
/// Create a message loop that processes messages sequentially
/// </summary>
public void OnMessage(Func<RedisChannel, RedisValue, Task> handler)
{
AssertOnMessage(handler);
ThreadPool.QueueUserWorkItem(
state => ((ChannelMessageQueue)state).OnMessageAsyncImpl(), this);
}
private async void OnMessageAsyncImpl()
{
var handler = (Func<RedisChannel, RedisValue, Task>)_onMessageHandler;
while (!IsCompleted)
{
ChannelMessage next;
try { if (!TryRead(out next)) next = await ReadAsync(); }
catch (ChannelClosedException) { break; } // expected
catch (Exception ex)
{
_parent.multiplexer?.OnInternalError(ex);
break;
}
try
{
var task = handler.Invoke(next.Channel, next.Value);
if (task != null) await task;
}
catch { } // matches MessageCompletable
}
}
internal void UnsubscribeImpl(Exception error = null, CommandFlags flags = CommandFlags.None) internal void UnsubscribeImpl(Exception error = null, CommandFlags flags = CommandFlags.None)
{ {
var parent = _parent; var parent = _parent;
...@@ -106,7 +179,8 @@ internal static bool IsOneOf(Action<RedisChannel, RedisValue> handler) ...@@ -106,7 +179,8 @@ internal static bool IsOneOf(Action<RedisChannel, RedisValue> handler)
{ {
return handler != null && handler.Target is ChannelMessageQueue return handler != null && handler.Target is ChannelMessageQueue
&& handler.Method.Name == nameof(HandleMessage); && handler.Method.Name == nameof(HandleMessage);
} catch }
catch
{ {
return false; return false;
} }
......
...@@ -49,7 +49,7 @@ public bool TryComplete(bool isAsync) ...@@ -49,7 +49,7 @@ public bool TryComplete(bool isAsync)
} }
ConnectionMultiplexer.TraceWithoutContext("Invoke complete (sync)", "Subscription"); ConnectionMultiplexer.TraceWithoutContext("Invoke complete (sync)", "Subscription");
} }
return asyncHandler != null; // anything async to do? return asyncHandler == null; // anything async to do?
} }
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment