Commit c1fd3b0b authored by Marc Gravell's avatar Marc Gravell

move channel tracking to the server instead of the client - more efficient for publish

parent c7627a9a
......@@ -186,7 +186,6 @@ public void StringSet()
}
}
/// <summary>
/// Run StringGet lots of times
/// </summary>
......
using System;
using System.Collections.Generic;
using System.IO.Pipelines;
using System.Threading;
using System.Threading.Tasks;
......@@ -18,52 +17,45 @@ internal bool ShouldSkipResponse()
}
return false;
}
private HashSet<RedisChannel> _subscripions;
public int SubscriptionCount => _subscripions?.Count ?? 0;
internal int Subscribe(RedisChannel channel)
{
var subs = _subscripions;
if(subs == null)
{
subs = new HashSet<RedisChannel>(); // but need to watch for compete
subs = Interlocked.CompareExchange(ref _subscripions, subs, null) ?? subs;
}
lock (subs)
[Flags]
private enum ClientFlags
{
subs.Add(channel);
return subs.Count;
}
None = 0,
Closed = 1 << 0,
HasHadSubscsription = 1 << 1
}
internal int Unsubscribe(RedisChannel channel)
{
var subs = _subscripions;
if (subs == null) return 0;
lock (subs)
private ClientFlags _flags;
private bool HasFlag(ClientFlags flag) => (_flags & flag) != 0;
private void SetFlag(ClientFlags flag, bool value)
{
subs.Remove(channel);
return subs.Count;
}
if (value) _flags |= flag;
else _flags &= ~flag;
}
internal bool IsSubscribed(RedisChannel channel)
{
var subs = _subscripions;
if (subs == null) return false;
lock (subs)
private int _subscriptionCount;
public bool HasHadSubscsription => HasFlag(ClientFlags.HasHadSubscsription);
public int SubscriptionCount => Thread.VolatileRead(ref _subscriptionCount);
internal int IncrSubscsriptionCount()
{
return subs.Contains(channel);
}
SetFlag(ClientFlags.HasHadSubscsription, true);
return Interlocked.Increment(ref _subscriptionCount);
}
internal int DecrSubscsriptionCount() => Interlocked.Decrement(ref _subscriptionCount);
public int Database { get; set; }
public string Name { get; set; }
internal IDuplexPipe LinkedPipe { get; set; }
public bool Closed { get; internal set; }
public bool Closed => HasFlag(ClientFlags.Closed);
public int Id { get; internal set; }
internal void SetClosed() => SetFlag(ClientFlags.Closed, true);
public void Dispose()
{
Closed = true;
SetClosed();
var pipe = LinkedPipe;
LinkedPipe = null;
if (pipe != null)
......
......@@ -499,19 +499,33 @@ private TypedRedisValue SubscribeImpl(RedisClient client, RedisRequest request)
{
var channel = request.GetChannel(i, mode);
int count;
lock (_fullSubs)
{
count = client.SubscriptionCount;
if (s_Subscribe.Equals(cmd))
{
count = client.Subscribe(channel);
if(!_fullSubs.TryGetValue(channel, out var clients))
{
clients = new HashSet<RedisClient>();
_fullSubs.Add(channel, clients);
}
if (clients.Add(client)) count = client.IncrSubscsriptionCount();
}
else if (s_Unsubscribe.Equals(cmd))
{
count = client.Unsubscribe(channel);
if (_fullSubs.TryGetValue(channel, out var clients)
&& clients.Remove(client))
{
count = client.DecrSubscsriptionCount();
}
}
else
{
reply.Recycle(index);
return TypedRedisValue.Nil;
}
}
span[index++] = cmdString;
span[index++] = TypedRedisValue.BulkString((byte[])channel);
span[index++] = TypedRedisValue.Integer(count);
......@@ -523,8 +537,8 @@ private static readonly CommandBytes
s_Unsubscribe = new CommandBytes("unsubscribe");
private static readonly DateTime UnixEpoch = new DateTime(1970, 1, 1, 0, 0, 0, DateTimeKind.Utc);
static readonly RedisValue s_MESSAGE = Encoding.ASCII.GetBytes("message");
static TypedRedisValue CreateBroadcastMessage(RedisChannel channel, RedisValue payload)
private static readonly RedisValue s_MESSAGE = Encoding.ASCII.GetBytes("message");
private static TypedRedisValue CreateBroadcastMessage(RedisChannel channel, RedisValue payload)
{
var msg = TypedRedisValue.Rent(3, out var span);
span[0] = TypedRedisValue.BulkString(s_MESSAGE);
......@@ -550,16 +564,23 @@ private async ValueTask<bool> TrySendOutOfBandAsync(RedisClient client, TypedRed
}
private async Task BackgroundPublish(ArraySegment<RedisClient> clients, RedisChannel channel, RedisValue payload)
{
try
{
var msg = CreateBroadcastMessage(channel, payload);
foreach (var sub in clients)
{
await TrySendOutOfBandAsync(sub, msg);
await TrySendOutOfBandAsync(sub, msg).ConfigureAwait(false);
}
// only recycle on success, to avoid issues
msg.Recycle();
ArrayPool<RedisClient>.Shared.Return(clients.Array);
}
catch (Exception ex)
{
Debug.WriteLine(ex.Message);
}
}
[RedisCommand(3, LockFree = true)]
protected virtual TypedRedisValue Publish(RedisClient client, RedisRequest request)
......@@ -576,16 +597,58 @@ protected virtual TypedRedisValue Publish(RedisClient client, RedisRequest reque
return TypedRedisValue.Integer(count);
}
private readonly Dictionary<RedisChannel, HashSet<RedisClient>> _fullSubs = new Dictionary<RedisChannel, HashSet<RedisClient>>();
protected ArraySegment<RedisClient> FilterSubscribers(RedisChannel channel)
{
lock (_fullSubs)
{
if (!_fullSubs.TryGetValue(channel, out var clients)) return default;
var arr = ArrayPool<RedisClient>.Shared.Rent(clients.Count);
clients.CopyTo(arr);
return new ArraySegment<RedisClient>(arr, 0, clients.Count);
}
}
protected override void OnRemoveClient(RedisClient client)
{
lock (_fullSubs)
{
List<RedisChannel> nowEmpty = null;
foreach(var pair in _fullSubs)
{
var set = pair.Value;
if(set.Remove(client) && set.Count == 0)
{
(nowEmpty ?? (nowEmpty = new List<RedisChannel>())).Add(pair.Key);
}
}
if(nowEmpty != null)
{
foreach (var channel in nowEmpty) _fullSubs.Remove(channel);
}
}
base.OnRemoveClient(client);
}
[RedisCommand(-3, "pubsub", "numsub", LockFree = true)]
protected virtual TypedRedisValue PubsubNumsub(RedisClient client, RedisRequest request)
{
var channel = request.GetChannel(2, RedisChannel.PatternMode.Literal);
var subscribers = FilterSubscribers(channel);
int count = subscribers.Count;
if (count != 0) ArrayPool<RedisClient>.Shared.Return(subscribers.Array);
return TypedRedisValue.Integer(count);
var reply = TypedRedisValue.Rent((request.Count - 2) * 2, out var span);
int index = 0;
lock (_fullSubs)
{
for(int i = 2; i < request.Count; i++)
{
var channel = request.GetChannel(i, RedisChannel.PatternMode.Literal);
var count = _fullSubs.TryGetValue(channel, out var clients) ? clients.Count : 0;
span[index++] = TypedRedisValue.BulkString(channel.Value);
span[index++] = TypedRedisValue.Integer(count);
}
}
return reply;
}
[RedisCommand(1, LockFree = true)]
protected virtual TypedRedisValue Time(RedisClient client, RedisRequest request)
......
......@@ -205,29 +205,6 @@ public int ClientCount
get { lock (_clients) { return _clients.Count; } }
}
protected ArraySegment<RedisClient> FilterSubscribers(RedisChannel channel)
{
lock (_clients)
{
var arr = ArrayPool<RedisClient>.Shared.Rent(_clients.Count);
int count = 0;
foreach(var client in _clients)
{
if (client.IsSubscribed(channel))
arr[count++] = client;
}
if (count == 0)
{
ArrayPool<RedisClient>.Shared.Return(arr);
return default; // Count=0, importantly
}
else
{
return new ArraySegment<RedisClient>(arr, 0, count);
}
}
}
public int TotalClientCount { get; private set; }
private int _nextId;
public RedisClient AddClient(IDuplexPipe pipe)
......@@ -245,12 +222,16 @@ public RedisClient AddClient(IDuplexPipe pipe)
public bool RemoveClient(RedisClient client)
{
if (client == null) return false;
bool result;
lock (_clients)
{
client.Closed = true;
return _clients.Remove(client);
client.SetClosed();
result = _clients.Remove(client);
}
if (result) OnRemoveClient(client);
return result;
}
protected virtual void OnRemoveClient(RedisClient client) { }
private readonly TaskCompletionSource<ShutdownReason> _shutdown = TaskSource.Create<ShutdownReason>(null, TaskCreationOptions.RunContinuationsAsynchronously);
private bool _isShutdown;
......@@ -349,9 +330,9 @@ void WritePrefix(PipeWriter ooutput, char pprefix)
if (client != null && client.ShouldSkipResponse()) return; // intentionally skipping the result
bool haveLock = false;
if (client != null && client.SubscriptionCount != 0)
if (client != null && client.HasHadSubscsription)
{
await client.TakeWriteLockAsync();
await client.TakeWriteLockAsync().ConfigureAwait(false);
haveLock = true;
}
try
......
......@@ -4,14 +4,15 @@
using System.Threading.Tasks;
using StackExchange.Redis;
static class Program
namespace TestsConsole
{
private static int taskCount = 10;
private static int totalRecords = 100000;
static void Main()
internal static class Program
{
private const int TaskCount = 10;
private const int TotalRecords = 100000;
private static void Main()
{
#if SEV2
Pipelines.Sockets.Unofficial.SocketConnection.AssertDependencies();
Console.WriteLine("We loaded the things...");
......@@ -23,7 +24,7 @@ static void Main()
var taskList = new List<Task>();
var connection = ConnectionMultiplexer.Connect("127.0.0.1");
for (int i = 0; i < taskCount; i++)
for (int i = 0; i < TaskCount; i++)
{
var i1 = i;
var task = new Task(() => Run(i1, connection));
......@@ -39,22 +40,23 @@ static void Main()
Console.ReadLine();
}
static void Run(int taskId, ConnectionMultiplexer connection)
private static void Run(int taskId, ConnectionMultiplexer connection)
{
Console.WriteLine($"{taskId} Started");
var database = connection.GetDatabase(0);
for (int i = 0; i < totalRecords; i++)
for (int i = 0; i < TotalRecords; i++)
{
database.StringSet(i.ToString(), i.ToString());
}
Console.WriteLine($"{taskId} Insert completed");
for (int i = 0; i < totalRecords; i++)
for (int i = 0; i < TotalRecords; i++)
{
var result = database.StringGet(i.ToString());
}
Console.WriteLine($"{taskId} Completed");
}
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment