Commit 4b03e8d2 authored by Nick Craver's avatar Nick Craver

Cleanup: RedisSubscriber

parent d6681264
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
namespace StackExchange.Redis namespace StackExchange.Redis
{ {
partial class ConnectionMultiplexer public partial class ConnectionMultiplexer
{ {
private readonly Dictionary<RedisChannel, Subscription> subscriptions = new Dictionary<RedisChannel, Subscription>(); private readonly Dictionary<RedisChannel, Subscription> subscriptions = new Dictionary<RedisChannel, Subscription>();
...@@ -33,8 +33,7 @@ internal Task AddSubscription(RedisChannel channel, Action<RedisChannel, RedisVa ...@@ -33,8 +33,7 @@ internal Task AddSubscription(RedisChannel channel, Action<RedisChannel, RedisVa
{ {
lock (subscriptions) lock (subscriptions)
{ {
Subscription sub; if (subscriptions.TryGetValue(channel, out Subscription sub))
if (subscriptions.TryGetValue(channel, out sub))
{ {
sub.Add(handler); sub.Add(handler);
} }
...@@ -45,7 +44,6 @@ internal Task AddSubscription(RedisChannel channel, Action<RedisChannel, RedisVa ...@@ -45,7 +44,6 @@ internal Task AddSubscription(RedisChannel channel, Action<RedisChannel, RedisVa
var task = sub.SubscribeToServer(this, channel, flags, asyncState, false); var task = sub.SubscribeToServer(this, channel, flags, asyncState, false);
if (task != null) return task; if (task != null) return task;
} }
} }
} }
return CompletedTask<bool>.Default(asyncState); return CompletedTask<bool>.Default(asyncState);
...@@ -57,8 +55,7 @@ internal ServerEndPoint GetSubscribedServer(RedisChannel channel) ...@@ -57,8 +55,7 @@ internal ServerEndPoint GetSubscribedServer(RedisChannel channel)
{ {
lock (subscriptions) lock (subscriptions)
{ {
Subscription sub; if (subscriptions.TryGetValue(channel, out Subscription sub))
if (subscriptions.TryGetValue(channel, out sub))
{ {
return sub.GetOwner(); return sub.GetOwner();
} }
...@@ -72,8 +69,7 @@ internal void OnMessage(RedisChannel subscription, RedisChannel channel, RedisVa ...@@ -72,8 +69,7 @@ internal void OnMessage(RedisChannel subscription, RedisChannel channel, RedisVa
ICompletable completable = null; ICompletable completable = null;
lock (subscriptions) lock (subscriptions)
{ {
Subscription sub; if (subscriptions.TryGetValue(subscription, out Subscription sub))
if (subscriptions.TryGetValue(subscription, out sub))
{ {
completable = sub.ForInvoke(channel, payload); completable = sub.ForInvoke(channel, payload);
} }
...@@ -101,17 +97,13 @@ internal Task RemoveSubscription(RedisChannel channel, Action<RedisChannel, Redi ...@@ -101,17 +97,13 @@ internal Task RemoveSubscription(RedisChannel channel, Action<RedisChannel, Redi
{ {
lock (subscriptions) lock (subscriptions)
{ {
Subscription sub; if (subscriptions.TryGetValue(channel, out Subscription sub) && sub.Remove(handler))
if (subscriptions.TryGetValue(channel, out sub))
{
if (sub.Remove(handler))
{ {
subscriptions.Remove(channel); subscriptions.Remove(channel);
var task = sub.UnsubscribeFromServer(channel, flags, asyncState, false); var task = sub.UnsubscribeFromServer(channel, flags, asyncState, false);
if (task != null) return task; if (task != null) return task;
} }
} }
}
return CompletedTask<bool>.Default(asyncState); return CompletedTask<bool>.Default(asyncState);
} }
...@@ -133,11 +125,9 @@ internal bool SubscriberConnected(RedisChannel channel = default(RedisChannel)) ...@@ -133,11 +125,9 @@ internal bool SubscriberConnected(RedisChannel channel = default(RedisChannel))
if (server != null) return server.IsConnected; if (server != null) return server.IsConnected;
server = SelectServer(-1, RedisCommand.SUBSCRIBE, CommandFlags.DemandMaster, default(RedisKey)); server = SelectServer(-1, RedisCommand.SUBSCRIBE, CommandFlags.DemandMaster, default(RedisKey));
return server != null && server.IsConnected; return server?.IsConnected == true;
} }
internal long ValidateSubscriptions() internal long ValidateSubscriptions()
{ {
lock (subscriptions) lock (subscriptions)
...@@ -156,14 +146,10 @@ private sealed class Subscription ...@@ -156,14 +146,10 @@ private sealed class Subscription
private Action<RedisChannel, RedisValue> handler; private Action<RedisChannel, RedisValue> handler;
private ServerEndPoint owner; private ServerEndPoint owner;
public Subscription(Action<RedisChannel, RedisValue> value) public Subscription(Action<RedisChannel, RedisValue> value) => handler = value;
{
handler = value; public void Add(Action<RedisChannel, RedisValue> value) => handler += value;
}
public void Add(Action<RedisChannel, RedisValue> value)
{
handler += value;
}
public ICompletable ForInvoke(RedisChannel channel, RedisValue message) public ICompletable ForInvoke(RedisChannel channel, RedisValue message)
{ {
var tmp = handler; var tmp = handler;
...@@ -182,6 +168,7 @@ public bool Remove(Action<RedisChannel, RedisValue> value) ...@@ -182,6 +168,7 @@ public bool Remove(Action<RedisChannel, RedisValue> value)
return (handler -= value) == null; return (handler -= value) == null;
} }
} }
public Task SubscribeToServer(ConnectionMultiplexer multiplexer, RedisChannel channel, CommandFlags flags, object asyncState, bool internalCall) public Task SubscribeToServer(ConnectionMultiplexer multiplexer, RedisChannel channel, CommandFlags flags, object asyncState, bool internalCall)
{ {
var cmd = channel.IsPatternBased ? RedisCommand.PSUBSCRIBE : RedisCommand.SUBSCRIBE; var cmd = channel.IsPatternBased ? RedisCommand.PSUBSCRIBE : RedisCommand.SUBSCRIBE;
...@@ -205,10 +192,8 @@ public Task UnsubscribeFromServer(RedisChannel channel, CommandFlags flags, obje ...@@ -205,10 +192,8 @@ public Task UnsubscribeFromServer(RedisChannel channel, CommandFlags flags, obje
return oldOwner.QueueDirectAsync(msg, ResultProcessor.TrackSubscriptions, asyncState); return oldOwner.QueueDirectAsync(msg, ResultProcessor.TrackSubscriptions, asyncState);
} }
internal ServerEndPoint GetOwner() internal ServerEndPoint GetOwner() => Interlocked.CompareExchange(ref owner, null, null);
{
return Interlocked.CompareExchange(ref owner, null, null);
}
internal void Resubscribe(RedisChannel channel, ServerEndPoint server) internal void Resubscribe(RedisChannel channel, ServerEndPoint server)
{ {
if (server != null && Interlocked.CompareExchange(ref owner, server, server) == server) if (server != null && Interlocked.CompareExchange(ref owner, server, server) == server)
...@@ -232,17 +217,12 @@ internal bool Validate(ConnectionMultiplexer multiplexer, RedisChannel channel) ...@@ -232,17 +217,12 @@ internal bool Validate(ConnectionMultiplexer multiplexer, RedisChannel channel)
} }
oldOwner = null; oldOwner = null;
} }
if (oldOwner == null) if (oldOwner == null && SubscribeToServer(multiplexer, channel, CommandFlags.FireAndForget, null, true) != null)
{
if (SubscribeToServer(multiplexer, channel, CommandFlags.FireAndForget, null, true) != null)
{ {
changed = true; changed = true;
} }
}
return changed; return changed;
} }
} }
} }
...@@ -309,12 +289,10 @@ public void Subscribe(RedisChannel channel, Action<RedisChannel, RedisValue> han ...@@ -309,12 +289,10 @@ public void Subscribe(RedisChannel channel, Action<RedisChannel, RedisValue> han
public Task SubscribeAsync(RedisChannel channel, Action<RedisChannel, RedisValue> handler, CommandFlags flags = CommandFlags.None) public Task SubscribeAsync(RedisChannel channel, Action<RedisChannel, RedisValue> handler, CommandFlags flags = CommandFlags.None)
{ {
if (channel.IsNullOrEmpty) throw new ArgumentNullException(nameof(channel)); if (channel.IsNullOrEmpty) throw new ArgumentNullException(nameof(channel));
return multiplexer.AddSubscription(channel, handler, flags, asyncState); return multiplexer.AddSubscription(channel, handler, flags, asyncState);
} }
public EndPoint SubscribedEndpoint(RedisChannel channel) public EndPoint SubscribedEndpoint(RedisChannel channel)
{ {
var server = multiplexer.GetSubscribedServer(channel); var server = multiplexer.GetSubscribedServer(channel);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment