Commit 9b901af3 authored by mgravell's avatar mgravell

additional for #1101 - make sure the queue gets marked as completed when...

additional for #1101 - make sure the queue gets marked as completed when unsubscribing; simplify the way unsubscribe works for CMQ - remove the ForSyncShutdown concept and just mark completed directly
parent 2876ed04
...@@ -94,14 +94,7 @@ private void HandleMessage(RedisChannel channel, RedisValue value) ...@@ -94,14 +94,7 @@ private void HandleMessage(RedisChannel channel, RedisValue value)
#pragma warning restore RCS1231 // Make parameter ref read-only. #pragma warning restore RCS1231 // Make parameter ref read-only.
{ {
var writer = _queue.Writer; var writer = _queue.Writer;
if (channel.IsNull && value.IsNull) // see ForSyncShutdown writer.TryWrite(new ChannelMessage(this, channel, value));
{
writer.TryComplete();
}
else
{
writer.TryWrite(new ChannelMessage(this, channel, value));
}
} }
/// <summary> /// <summary>
...@@ -212,6 +205,23 @@ private async Task OnMessageAsyncImpl() ...@@ -212,6 +205,23 @@ private async Task OnMessageAsyncImpl()
} }
} }
internal static void MarkCompleted(Action<RedisChannel, RedisValue> handler)
{
if (handler != null)
{
foreach (Action<RedisChannel, RedisValue> sub in handler.GetInvocationList())
{
if (sub.Target is ChannelMessageQueue queue) queue.MarkCompleted();
}
}
}
private void MarkCompleted(Exception error = null)
{
_parent = null;
_queue.Writer.TryComplete(error);
}
internal void UnsubscribeImpl(Exception error = null, CommandFlags flags = CommandFlags.None) internal void UnsubscribeImpl(Exception error = null, CommandFlags flags = CommandFlags.None)
{ {
var parent = _parent; var parent = _parent;
......
...@@ -359,8 +359,8 @@ public Task FlushAsync() ...@@ -359,8 +359,8 @@ public Task FlushAsync()
var data = new List<Tuple<string, string>>(); var data = new List<Tuple<string, string>>();
void add(string lk, string sk, string v) void add(string lk, string sk, string v)
{ {
data.Add(Tuple.Create(lk, v)); if (lk != null) data.Add(Tuple.Create(lk, v));
exMessage.Append(", ").Append(sk).Append(": ").Append(v); if (sk != null) exMessage.Append(", ").Append(sk).Append(": ").Append(v);
} }
if (IncludeDetailInExceptions) if (IncludeDetailInExceptions)
......
...@@ -90,36 +90,46 @@ internal void OnMessage(in RedisChannel subscription, in RedisChannel channel, i ...@@ -90,36 +90,46 @@ internal void OnMessage(in RedisChannel subscription, in RedisChannel channel, i
internal Task RemoveAllSubscriptions(CommandFlags flags, object asyncState) internal Task RemoveAllSubscriptions(CommandFlags flags, object asyncState)
{ {
Task last = CompletedTask<bool>.Default(asyncState); Task last = null;
lock (subscriptions) lock (subscriptions)
{ {
foreach (var pair in subscriptions) foreach (var pair in subscriptions)
{ {
var msg = pair.Value.ForSyncShutdown(); pair.Value.MarkCompleted();
if (msg != null && !msg.TryComplete(false)) ConnectionMultiplexer.CompleteAsWorker(msg);
pair.Value.Remove(default, null); // when passing null, it wipes both sync+async
var task = pair.Value.UnsubscribeFromServer(pair.Key, flags, asyncState, false); var task = pair.Value.UnsubscribeFromServer(pair.Key, flags, asyncState, false);
if (task != null) last = task; if (task != null) last = task;
} }
subscriptions.Clear(); subscriptions.Clear();
} }
return last; return last ?? CompletedTask<bool>.Default(asyncState);
} }
internal Task RemoveSubscription(in RedisChannel channel, Action<RedisChannel, RedisValue> handler, CommandFlags flags, object asyncState) internal Task RemoveSubscription(in RedisChannel channel, Action<RedisChannel, RedisValue> handler, CommandFlags flags, object asyncState)
{ {
Task task = null;
lock (subscriptions) lock (subscriptions)
{ {
bool asAsync = !ChannelMessageQueue.IsOneOf(handler); if (subscriptions.TryGetValue(channel, out Subscription sub))
if (subscriptions.TryGetValue(channel, out Subscription sub) && sub.Remove(asAsync, handler))
{ {
subscriptions.Remove(channel); bool remove;
var task = sub.UnsubscribeFromServer(channel, flags, asyncState, false); if (handler == null) // blanket wipe
if (task != null) return task; {
sub.MarkCompleted();
remove = true;
}
else
{
bool asAsync = !ChannelMessageQueue.IsOneOf(handler);
remove = sub.Remove(asAsync, handler);
}
if (remove)
{
subscriptions.Remove(channel);
task = sub.UnsubscribeFromServer(channel, flags, asyncState, false);
}
} }
} }
return CompletedTask<bool>.Default(asyncState); return task ?? CompletedTask<bool>.Default(asyncState);
} }
internal void ResendSubscriptions(ServerEndPoint server) internal void ResendSubscriptions(ServerEndPoint server)
...@@ -173,11 +183,6 @@ public void Add(bool asAsync, Action<RedisChannel, RedisValue> value) ...@@ -173,11 +183,6 @@ public void Add(bool asAsync, Action<RedisChannel, RedisValue> value)
else _syncHandler += value; else _syncHandler += value;
} }
public ICompletable ForSyncShutdown()
{
var syncHandler = _syncHandler;
return syncHandler == null ? null : new MessageCompletable(default, default, syncHandler, null);
}
public ICompletable ForInvoke(in RedisChannel channel, in RedisValue message) public ICompletable ForInvoke(in RedisChannel channel, in RedisValue message)
{ {
var syncHandler = _syncHandler; var syncHandler = _syncHandler;
...@@ -185,14 +190,17 @@ public ICompletable ForInvoke(in RedisChannel channel, in RedisValue message) ...@@ -185,14 +190,17 @@ public ICompletable ForInvoke(in RedisChannel channel, in RedisValue message)
return (syncHandler == null && asyncHandler == null) ? null : new MessageCompletable(channel, message, syncHandler, asyncHandler); return (syncHandler == null && asyncHandler == null) ? null : new MessageCompletable(channel, message, syncHandler, asyncHandler);
} }
internal void MarkCompleted()
{
_asyncHandler = null;
var oldSync = _syncHandler;
_syncHandler = null;
ChannelMessageQueue.MarkCompleted(oldSync);
}
public bool Remove(bool asAsync, Action<RedisChannel, RedisValue> value) public bool Remove(bool asAsync, Action<RedisChannel, RedisValue> value)
{ {
if (value == null) if (value != null)
{ // treat as blanket wipe
_asyncHandler = null;
_syncHandler = null;
}
else
{ {
if (asAsync) _asyncHandler -= value; if (asAsync) _asyncHandler -= value;
else _syncHandler -= value; else _syncHandler -= value;
......
...@@ -23,7 +23,7 @@ public async Task ExecuteWithUnsubscribeViaChannel() ...@@ -23,7 +23,7 @@ public async Task ExecuteWithUnsubscribeViaChannel()
List<string> values = new List<string>(); List<string> values = new List<string>();
channel.OnMessage(x => channel.OnMessage(x =>
{ {
lock(values) { values.Add(x.Message); } lock (values) { values.Add(x.Message); }
return Task.CompletedTask; return Task.CompletedTask;
}); });
await Task.Delay(100); await Task.Delay(100);
...@@ -35,6 +35,7 @@ public async Task ExecuteWithUnsubscribeViaChannel() ...@@ -35,6 +35,7 @@ public async Task ExecuteWithUnsubscribeViaChannel()
} }
var subs = muxer.GetServer(muxer.GetEndPoints().Single()).SubscriptionSubscriberCount(name); var subs = muxer.GetServer(muxer.GetEndPoints().Single()).SubscriptionSubscriberCount(name);
Assert.Equal(1, subs); Assert.Equal(1, subs);
Assert.False(channel.Completion.IsCompleted, "completed");
await channel.UnsubscribeAsync(); await channel.UnsubscribeAsync();
await Task.Delay(100); await Task.Delay(100);
...@@ -47,6 +48,7 @@ public async Task ExecuteWithUnsubscribeViaChannel() ...@@ -47,6 +48,7 @@ public async Task ExecuteWithUnsubscribeViaChannel()
subs = muxer.GetServer(muxer.GetEndPoints().Single()).SubscriptionSubscriberCount(name); subs = muxer.GetServer(muxer.GetEndPoints().Single()).SubscriptionSubscriberCount(name);
Assert.Equal(0, subs); Assert.Equal(0, subs);
Assert.True(channel.Completion.IsCompleted, "completed");
} }
} }
...@@ -75,6 +77,7 @@ public async Task ExecuteWithUnsubscribeViaSubscriber() ...@@ -75,6 +77,7 @@ public async Task ExecuteWithUnsubscribeViaSubscriber()
} }
var subs = muxer.GetServer(muxer.GetEndPoints().Single()).SubscriptionSubscriberCount(name); var subs = muxer.GetServer(muxer.GetEndPoints().Single()).SubscriptionSubscriberCount(name);
Assert.Equal(1, subs); Assert.Equal(1, subs);
Assert.False(channel.Completion.IsCompleted, "completed");
await pubsub.UnsubscribeAsync(name); await pubsub.UnsubscribeAsync(name);
await Task.Delay(100); await Task.Delay(100);
...@@ -87,6 +90,49 @@ public async Task ExecuteWithUnsubscribeViaSubscriber() ...@@ -87,6 +90,49 @@ public async Task ExecuteWithUnsubscribeViaSubscriber()
subs = muxer.GetServer(muxer.GetEndPoints().Single()).SubscriptionSubscriberCount(name); subs = muxer.GetServer(muxer.GetEndPoints().Single()).SubscriptionSubscriberCount(name);
Assert.Equal(0, subs); Assert.Equal(0, subs);
Assert.True(channel.Completion.IsCompleted, "completed");
}
}
[Fact]
public async Task ExecuteWithUnsubscribeViaClearAll()
{
using (var muxer = Create())
{
RedisChannel name = Me();
var pubsub = muxer.GetSubscriber();
// subscribe and check we get data
var channel = await pubsub.SubscribeAsync(name);
List<string> values = new List<string>();
channel.OnMessage(x =>
{
lock (values) { values.Add(x.Message); }
return Task.CompletedTask;
});
await Task.Delay(100);
await pubsub.PublishAsync(name, "abc");
await Task.Delay(100);
lock (values)
{
Assert.Equal("abc", Assert.Single(values));
}
var subs = muxer.GetServer(muxer.GetEndPoints().Single()).SubscriptionSubscriberCount(name);
Assert.Equal(1, subs);
Assert.False(channel.Completion.IsCompleted, "completed");
await pubsub.UnsubscribeAllAsync();
await Task.Delay(100);
await pubsub.PublishAsync(name, "def");
await Task.Delay(100);
lock (values)
{
Assert.Equal("abc", Assert.Single(values));
}
subs = muxer.GetServer(muxer.GetEndPoints().Single()).SubscriptionSubscriberCount(name);
Assert.Equal(0, subs);
Assert.True(channel.Completion.IsCompleted, "completed");
} }
} }
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment