Commit c1fd3b0b authored by Marc Gravell's avatar Marc Gravell

move channel tracking to the server instead of the client - more efficient for publish

parent c7627a9a
...@@ -186,7 +186,6 @@ public void StringSet() ...@@ -186,7 +186,6 @@ public void StringSet()
} }
} }
/// <summary> /// <summary>
/// Run StringGet lots of times /// Run StringGet lots of times
/// </summary> /// </summary>
......
using System; using System;
using System.Collections.Generic;
using System.IO.Pipelines; using System.IO.Pipelines;
using System.Threading; using System.Threading;
using System.Threading.Tasks; using System.Threading.Tasks;
...@@ -18,52 +17,45 @@ internal bool ShouldSkipResponse() ...@@ -18,52 +17,45 @@ internal bool ShouldSkipResponse()
} }
return false; return false;
} }
private HashSet<RedisChannel> _subscripions;
public int SubscriptionCount => _subscripions?.Count ?? 0; [Flags]
internal int Subscribe(RedisChannel channel) private enum ClientFlags
{ {
var subs = _subscripions; None = 0,
if(subs == null) Closed = 1 << 0,
{ HasHadSubscsription = 1 << 1
subs = new HashSet<RedisChannel>(); // but need to watch for compete
subs = Interlocked.CompareExchange(ref _subscripions, subs, null) ?? subs;
}
lock (subs)
{
subs.Add(channel);
return subs.Count;
}
} }
internal int Unsubscribe(RedisChannel channel) private ClientFlags _flags;
private bool HasFlag(ClientFlags flag) => (_flags & flag) != 0;
private void SetFlag(ClientFlags flag, bool value)
{ {
var subs = _subscripions; if (value) _flags |= flag;
if (subs == null) return 0; else _flags &= ~flag;
lock (subs)
{
subs.Remove(channel);
return subs.Count;
}
} }
internal bool IsSubscribed(RedisChannel channel) private int _subscriptionCount;
public bool HasHadSubscsription => HasFlag(ClientFlags.HasHadSubscsription);
public int SubscriptionCount => Thread.VolatileRead(ref _subscriptionCount);
internal int IncrSubscsriptionCount()
{ {
var subs = _subscripions; SetFlag(ClientFlags.HasHadSubscsription, true);
if (subs == null) return false; return Interlocked.Increment(ref _subscriptionCount);
lock (subs)
{
return subs.Contains(channel);
}
} }
internal int DecrSubscsriptionCount() => Interlocked.Decrement(ref _subscriptionCount);
public int Database { get; set; } public int Database { get; set; }
public string Name { get; set; } public string Name { get; set; }
internal IDuplexPipe LinkedPipe { get; set; } internal IDuplexPipe LinkedPipe { get; set; }
public bool Closed { get; internal set; } public bool Closed => HasFlag(ClientFlags.Closed);
public int Id { get; internal set; } public int Id { get; internal set; }
internal void SetClosed() => SetFlag(ClientFlags.Closed, true);
public void Dispose() public void Dispose()
{ {
Closed = true; SetClosed();
var pipe = LinkedPipe; var pipe = LinkedPipe;
LinkedPipe = null; LinkedPipe = null;
if (pipe != null) if (pipe != null)
......
...@@ -499,18 +499,32 @@ private TypedRedisValue SubscribeImpl(RedisClient client, RedisRequest request) ...@@ -499,18 +499,32 @@ private TypedRedisValue SubscribeImpl(RedisClient client, RedisRequest request)
{ {
var channel = request.GetChannel(i, mode); var channel = request.GetChannel(i, mode);
int count; int count;
if (s_Subscribe.Equals(cmd))
{ lock (_fullSubs)
count = client.Subscribe(channel);
}
else if (s_Unsubscribe.Equals(cmd))
{
count = client.Unsubscribe(channel);
}
else
{ {
reply.Recycle(index); count = client.SubscriptionCount;
return TypedRedisValue.Nil; if (s_Subscribe.Equals(cmd))
{
if(!_fullSubs.TryGetValue(channel, out var clients))
{
clients = new HashSet<RedisClient>();
_fullSubs.Add(channel, clients);
}
if (clients.Add(client)) count = client.IncrSubscsriptionCount();
}
else if (s_Unsubscribe.Equals(cmd))
{
if (_fullSubs.TryGetValue(channel, out var clients)
&& clients.Remove(client))
{
count = client.DecrSubscsriptionCount();
}
}
else
{
reply.Recycle(index);
return TypedRedisValue.Nil;
}
} }
span[index++] = cmdString; span[index++] = cmdString;
span[index++] = TypedRedisValue.BulkString((byte[])channel); span[index++] = TypedRedisValue.BulkString((byte[])channel);
...@@ -523,8 +537,8 @@ private static readonly CommandBytes ...@@ -523,8 +537,8 @@ private static readonly CommandBytes
s_Unsubscribe = new CommandBytes("unsubscribe"); s_Unsubscribe = new CommandBytes("unsubscribe");
private static readonly DateTime UnixEpoch = new DateTime(1970, 1, 1, 0, 0, 0, DateTimeKind.Utc); private static readonly DateTime UnixEpoch = new DateTime(1970, 1, 1, 0, 0, 0, DateTimeKind.Utc);
static readonly RedisValue s_MESSAGE = Encoding.ASCII.GetBytes("message"); private static readonly RedisValue s_MESSAGE = Encoding.ASCII.GetBytes("message");
static TypedRedisValue CreateBroadcastMessage(RedisChannel channel, RedisValue payload) private static TypedRedisValue CreateBroadcastMessage(RedisChannel channel, RedisValue payload)
{ {
var msg = TypedRedisValue.Rent(3, out var span); var msg = TypedRedisValue.Rent(3, out var span);
span[0] = TypedRedisValue.BulkString(s_MESSAGE); span[0] = TypedRedisValue.BulkString(s_MESSAGE);
...@@ -551,14 +565,21 @@ private async ValueTask<bool> TrySendOutOfBandAsync(RedisClient client, TypedRed ...@@ -551,14 +565,21 @@ private async ValueTask<bool> TrySendOutOfBandAsync(RedisClient client, TypedRed
private async Task BackgroundPublish(ArraySegment<RedisClient> clients, RedisChannel channel, RedisValue payload) private async Task BackgroundPublish(ArraySegment<RedisClient> clients, RedisChannel channel, RedisValue payload)
{ {
var msg = CreateBroadcastMessage(channel, payload); try
foreach (var sub in clients) {
var msg = CreateBroadcastMessage(channel, payload);
foreach (var sub in clients)
{
await TrySendOutOfBandAsync(sub, msg).ConfigureAwait(false);
}
// only recycle on success, to avoid issues
msg.Recycle();
ArrayPool<RedisClient>.Shared.Return(clients.Array);
}
catch (Exception ex)
{ {
await TrySendOutOfBandAsync(sub, msg); Debug.WriteLine(ex.Message);
} }
// only recycle on success, to avoid issues
msg.Recycle();
ArrayPool<RedisClient>.Shared.Return(clients.Array);
} }
[RedisCommand(3, LockFree = true)] [RedisCommand(3, LockFree = true)]
...@@ -576,17 +597,59 @@ protected virtual TypedRedisValue Publish(RedisClient client, RedisRequest reque ...@@ -576,17 +597,59 @@ protected virtual TypedRedisValue Publish(RedisClient client, RedisRequest reque
return TypedRedisValue.Integer(count); return TypedRedisValue.Integer(count);
} }
private readonly Dictionary<RedisChannel, HashSet<RedisClient>> _fullSubs = new Dictionary<RedisChannel, HashSet<RedisClient>>();
protected ArraySegment<RedisClient> FilterSubscribers(RedisChannel channel)
{
lock (_fullSubs)
{
if (!_fullSubs.TryGetValue(channel, out var clients)) return default;
var arr = ArrayPool<RedisClient>.Shared.Rent(clients.Count);
clients.CopyTo(arr);
return new ArraySegment<RedisClient>(arr, 0, clients.Count);
}
}
protected override void OnRemoveClient(RedisClient client)
{
lock (_fullSubs)
{
List<RedisChannel> nowEmpty = null;
foreach(var pair in _fullSubs)
{
var set = pair.Value;
if(set.Remove(client) && set.Count == 0)
{
(nowEmpty ?? (nowEmpty = new List<RedisChannel>())).Add(pair.Key);
}
}
if(nowEmpty != null)
{
foreach (var channel in nowEmpty) _fullSubs.Remove(channel);
}
}
base.OnRemoveClient(client);
}
[RedisCommand(-3, "pubsub", "numsub", LockFree = true)] [RedisCommand(-3, "pubsub", "numsub", LockFree = true)]
protected virtual TypedRedisValue PubsubNumsub(RedisClient client, RedisRequest request) protected virtual TypedRedisValue PubsubNumsub(RedisClient client, RedisRequest request)
{ {
var channel = request.GetChannel(2, RedisChannel.PatternMode.Literal); var reply = TypedRedisValue.Rent((request.Count - 2) * 2, out var span);
var subscribers = FilterSubscribers(channel); int index = 0;
int count = subscribers.Count; lock (_fullSubs)
if (count != 0) ArrayPool<RedisClient>.Shared.Return(subscribers.Array); {
return TypedRedisValue.Integer(count); for(int i = 2; i < request.Count; i++)
{
var channel = request.GetChannel(i, RedisChannel.PatternMode.Literal);
var count = _fullSubs.TryGetValue(channel, out var clients) ? clients.Count : 0;
span[index++] = TypedRedisValue.BulkString(channel.Value);
span[index++] = TypedRedisValue.Integer(count);
}
}
return reply;
} }
[RedisCommand(1, LockFree = true)] [RedisCommand(1, LockFree = true)]
protected virtual TypedRedisValue Time(RedisClient client, RedisRequest request) protected virtual TypedRedisValue Time(RedisClient client, RedisRequest request)
{ {
......
...@@ -205,29 +205,6 @@ public int ClientCount ...@@ -205,29 +205,6 @@ public int ClientCount
get { lock (_clients) { return _clients.Count; } } get { lock (_clients) { return _clients.Count; } }
} }
protected ArraySegment<RedisClient> FilterSubscribers(RedisChannel channel)
{
lock (_clients)
{
var arr = ArrayPool<RedisClient>.Shared.Rent(_clients.Count);
int count = 0;
foreach(var client in _clients)
{
if (client.IsSubscribed(channel))
arr[count++] = client;
}
if (count == 0)
{
ArrayPool<RedisClient>.Shared.Return(arr);
return default; // Count=0, importantly
}
else
{
return new ArraySegment<RedisClient>(arr, 0, count);
}
}
}
public int TotalClientCount { get; private set; } public int TotalClientCount { get; private set; }
private int _nextId; private int _nextId;
public RedisClient AddClient(IDuplexPipe pipe) public RedisClient AddClient(IDuplexPipe pipe)
...@@ -245,12 +222,16 @@ public RedisClient AddClient(IDuplexPipe pipe) ...@@ -245,12 +222,16 @@ public RedisClient AddClient(IDuplexPipe pipe)
public bool RemoveClient(RedisClient client) public bool RemoveClient(RedisClient client)
{ {
if (client == null) return false; if (client == null) return false;
bool result;
lock (_clients) lock (_clients)
{ {
client.Closed = true; client.SetClosed();
return _clients.Remove(client); result = _clients.Remove(client);
} }
if (result) OnRemoveClient(client);
return result;
} }
protected virtual void OnRemoveClient(RedisClient client) { }
private readonly TaskCompletionSource<ShutdownReason> _shutdown = TaskSource.Create<ShutdownReason>(null, TaskCreationOptions.RunContinuationsAsynchronously); private readonly TaskCompletionSource<ShutdownReason> _shutdown = TaskSource.Create<ShutdownReason>(null, TaskCreationOptions.RunContinuationsAsynchronously);
private bool _isShutdown; private bool _isShutdown;
...@@ -349,9 +330,9 @@ void WritePrefix(PipeWriter ooutput, char pprefix) ...@@ -349,9 +330,9 @@ void WritePrefix(PipeWriter ooutput, char pprefix)
if (client != null && client.ShouldSkipResponse()) return; // intentionally skipping the result if (client != null && client.ShouldSkipResponse()) return; // intentionally skipping the result
bool haveLock = false; bool haveLock = false;
if (client != null && client.SubscriptionCount != 0) if (client != null && client.HasHadSubscsription)
{ {
await client.TakeWriteLockAsync(); await client.TakeWriteLockAsync().ConfigureAwait(false);
haveLock = true; haveLock = true;
} }
try try
......
...@@ -4,57 +4,59 @@ ...@@ -4,57 +4,59 @@
using System.Threading.Tasks; using System.Threading.Tasks;
using StackExchange.Redis; using StackExchange.Redis;
static class Program namespace TestsConsole
{ {
private static int taskCount = 10; internal static class Program
private static int totalRecords = 100000;
static void Main()
{ {
private const int TaskCount = 10;
private const int TotalRecords = 100000;
private static void Main()
{
#if SEV2 #if SEV2
Pipelines.Sockets.Unofficial.SocketConnection.AssertDependencies(); Pipelines.Sockets.Unofficial.SocketConnection.AssertDependencies();
Console.WriteLine("We loaded the things..."); Console.WriteLine("We loaded the things...");
// Console.ReadLine(); // Console.ReadLine();
#endif #endif
Stopwatch stopwatch = new Stopwatch(); Stopwatch stopwatch = new Stopwatch();
stopwatch.Start(); stopwatch.Start();
var taskList = new List<Task>();
var connection = ConnectionMultiplexer.Connect("127.0.0.1");
for (int i = 0; i < taskCount; i++)
{
var i1 = i;
var task = new Task(() => Run(i1, connection));
task.Start();
taskList.Add(task);
}
Task.WaitAll(taskList.ToArray()); var taskList = new List<Task>();
var connection = ConnectionMultiplexer.Connect("127.0.0.1");
for (int i = 0; i < TaskCount; i++)
{
var i1 = i;
var task = new Task(() => Run(i1, connection));
task.Start();
taskList.Add(task);
}
stopwatch.Stop(); Task.WaitAll(taskList.ToArray());
Console.WriteLine($"Done. {stopwatch.ElapsedMilliseconds}"); stopwatch.Stop();
Console.ReadLine();
}
static void Run(int taskId, ConnectionMultiplexer connection) Console.WriteLine($"Done. {stopwatch.ElapsedMilliseconds}");
{ Console.ReadLine();
Console.WriteLine($"{taskId} Started"); }
var database = connection.GetDatabase(0);
for (int i = 0; i < totalRecords; i++) private static void Run(int taskId, ConnectionMultiplexer connection)
{ {
database.StringSet(i.ToString(), i.ToString()); Console.WriteLine($"{taskId} Started");
} var database = connection.GetDatabase(0);
Console.WriteLine($"{taskId} Insert completed"); for (int i = 0; i < TotalRecords; i++)
{
database.StringSet(i.ToString(), i.ToString());
}
for (int i = 0; i < totalRecords; i++) Console.WriteLine($"{taskId} Insert completed");
{
var result = database.StringGet(i.ToString()); for (int i = 0; i < TotalRecords; i++)
{
var result = database.StringGet(i.ToString());
}
Console.WriteLine($"{taskId} Completed");
} }
Console.WriteLine($"{taskId} Completed");
} }
} }
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment