Commit 93ab472f authored by Marc Gravell's avatar Marc Gravell

optimize pub/sub for short (<=23 bytes) channel keys

parent c1fd3b0b
......@@ -25,7 +25,7 @@ internal unsafe static CommandBytes TrimToFit(string value)
}
}
// Uses [n=4] x UInt64 values to store a command payload,
// Uses [ChunkLength] x UInt64 values to store a command payload,
// allowing allocation free storage and efficient
// equality tests. If you're glancing at this and thinking
// "that's what fixed buffers are for", please see:
......@@ -114,7 +114,7 @@ public unsafe CommandBytes(string value)
}
}
public unsafe CommandBytes(ReadOnlySpan<byte> value)
public unsafe CommandBytes(ReadOnlySpan<byte> value, bool caseInsensitive = true)
{
if (value.Length > MaxLength) throw new ArgumentOutOfRangeException("Maximum command length exceeed: " + value.Length + " bytes");
_0 = _1 = _2 = 0L;
......@@ -122,10 +122,17 @@ public unsafe CommandBytes(ReadOnlySpan<byte> value)
{
byte* bPtr = (byte*)uPtr;
value.CopyTo(new Span<byte>(bPtr + 1, value.Length));
*bPtr = (byte)UpperCasify(value.Length, bPtr + 1);
if (caseInsensitive)
{
*bPtr = (byte)UpperCasify(value.Length, bPtr + 1);
}
else
{
*bPtr = (byte)value.Length;
}
}
}
public unsafe CommandBytes(ReadOnlySequence<byte> value)
public unsafe CommandBytes(ReadOnlySequence<byte> value, bool caseInsensitive = true)
{
if (value.Length > MaxLength) throw new ArgumentOutOfRangeException("Maximum command length exceeed");
int len = unchecked((int)value.Length);
......@@ -147,7 +154,14 @@ public unsafe CommandBytes(ReadOnlySequence<byte> value)
target = target.Slice(segment.Length);
}
}
*bPtr = (byte)UpperCasify(len, bPtr + 1);
if (caseInsensitive)
{
*bPtr = (byte)UpperCasify(len, bPtr + 1);
}
else
{
*bPtr = (byte)len;
}
}
}
private unsafe int UpperCasify(int len, byte* bPtr)
......
......@@ -49,16 +49,16 @@ public int GetInt32(int index)
public RedisChannel GetChannel(int index, RedisChannel.PatternMode mode)
=> _inner[index].AsRedisChannel(null, mode);
internal bool TryGetCommandBytes(int i, out CommandBytes command)
internal bool TryGetCommandBytes(int index, out CommandBytes command, bool caseInsensitive = true)
{
var payload = _inner[i].Payload;
var payload = _inner[index].Payload;
if (payload.Length > CommandBytes.MaxLength)
{
command = default;
return false;
}
command = payload.IsEmpty ? default : new CommandBytes(payload);
command = payload.IsEmpty ? default : new CommandBytes(payload, caseInsensitive);
return true;
}
}
......
......@@ -497,37 +497,82 @@ private TypedRedisValue SubscribeImpl(RedisClient client, RedisRequest request)
var mode = cmd[0] == (byte)'p' ? RedisChannel.PatternMode.Pattern : RedisChannel.PatternMode.Literal;
for (int i = 1; i < request.Count; i++)
{
var channel = request.GetChannel(i, mode);
int count;
lock (_fullSubs)
TypedRedisValue channel;
if (request.TryGetCommandBytes(i, out var shortChannel, caseInsensitive: false))
{
count = client.SubscriptionCount;
if (s_Subscribe.Equals(cmd))
lock (_shortSubs)
{
if(!_fullSubs.TryGetValue(channel, out var clients))
count = client.SubscriptionCount;
if (s_Subscribe.Equals(cmd))
{
clients = new HashSet<RedisClient>();
_fullSubs.Add(channel, clients);
if (!_shortSubs.TryGetValue(shortChannel, out var tmp))
{
tmp = (request.GetValue(i), new HashSet<RedisClient>());
_shortSubs.Add(shortChannel, tmp);
}
channel = TypedRedisValue.BulkString(tmp.Channel);
if (tmp.Clients.Add(client)) count = client.IncrSubscsriptionCount();
}
if (clients.Add(client)) count = client.IncrSubscsriptionCount();
}
else if (s_Unsubscribe.Equals(cmd))
{
if (_fullSubs.TryGetValue(channel, out var clients)
&& clients.Remove(client))
else if (s_Unsubscribe.Equals(cmd))
{
count = client.DecrSubscsriptionCount();
if (_shortSubs.TryGetValue(shortChannel, out var tmp))
{
channel = TypedRedisValue.BulkString(tmp.Channel);
if (tmp.Clients.Remove(client))
{
count = client.DecrSubscsriptionCount();
}
}
else
{
channel = TypedRedisValue.BulkString(request.GetValue(i));
}
}
else
{
reply.Recycle(index);
return TypedRedisValue.Nil;
}
// channel = TypedRedisValue.BulkString((byte[])longChannel);
}
else
}
else
{
var longChannel = request.GetChannel(i, mode);
lock (_longSubs)
{
reply.Recycle(index);
return TypedRedisValue.Nil;
count = client.SubscriptionCount;
if (s_Subscribe.Equals(cmd))
{
if (!_longSubs.TryGetValue(longChannel, out var clients))
{
clients = new HashSet<RedisClient>();
_longSubs.Add(longChannel, clients);
}
if (clients.Add(client)) count = client.IncrSubscsriptionCount();
}
else if (s_Unsubscribe.Equals(cmd))
{
if (_longSubs.TryGetValue(longChannel, out var clients)
&& clients.Remove(client))
{
count = client.DecrSubscsriptionCount();
}
}
else
{
reply.Recycle(index);
return TypedRedisValue.Nil;
}
channel = TypedRedisValue.BulkString((byte[])longChannel);
}
}
span[index++] = cmdString;
span[index++] = TypedRedisValue.BulkString((byte[])channel);
span[index++] = channel;
span[index++] = TypedRedisValue.Integer(count);
}
return reply;
......@@ -538,11 +583,11 @@ private static readonly CommandBytes
private static readonly DateTime UnixEpoch = new DateTime(1970, 1, 1, 0, 0, 0, DateTimeKind.Utc);
private static readonly RedisValue s_MESSAGE = Encoding.ASCII.GetBytes("message");
private static TypedRedisValue CreateBroadcastMessage(RedisChannel channel, RedisValue payload)
private static TypedRedisValue CreateBroadcastMessage(RedisValue channel, RedisValue payload)
{
var msg = TypedRedisValue.Rent(3, out var span);
span[0] = TypedRedisValue.BulkString(s_MESSAGE);
span[1] = TypedRedisValue.BulkString(channel.Value);
span[1] = TypedRedisValue.BulkString(channel);
span[2] = TypedRedisValue.BulkString(payload);
return msg;
}
......@@ -563,7 +608,7 @@ private async ValueTask<bool> TrySendOutOfBandAsync(RedisClient client, TypedRed
}
}
private async Task BackgroundPublish(ArraySegment<RedisClient> clients, RedisChannel channel, RedisValue payload)
private async Task BackgroundPublish(ArraySegment<RedisClient> clients, RedisValue channel, RedisValue payload)
{
try
{
......@@ -585,10 +630,19 @@ private async Task BackgroundPublish(ArraySegment<RedisClient> clients, RedisCha
[RedisCommand(3, LockFree = true)]
protected virtual TypedRedisValue Publish(RedisClient client, RedisRequest request)
{
var channel = request.GetChannel(1, RedisChannel.PatternMode.Literal);
var subscribers = FilterSubscribers(channel);
int count = subscribers.Count;
ArraySegment<RedisClient> subscribers;
RedisValue channel;
if (request.TryGetCommandBytes(1, out var shortChannel, caseInsensitive: false))
{
subscribers = FilterSubscribers(shortChannel, out channel);
}
else
{
var longChannel = request.GetChannel(1, RedisChannel.PatternMode.Literal);
subscribers = FilterSubscribers(longChannel);
channel = longChannel.Value;
}
var count = subscribers.Count;
if (count != 0)
{
var payload = request.GetValue(2);
......@@ -597,26 +651,54 @@ protected virtual TypedRedisValue Publish(RedisClient client, RedisRequest reque
return TypedRedisValue.Integer(count);
}
private readonly Dictionary<RedisChannel, HashSet<RedisClient>> _fullSubs = new Dictionary<RedisChannel, HashSet<RedisClient>>();
private readonly Dictionary<RedisChannel, HashSet<RedisClient>> _longSubs = new Dictionary<RedisChannel, HashSet<RedisClient>>();
private readonly Dictionary<CommandBytes, (RedisValue Channel,HashSet<RedisClient> Clients)> _shortSubs = new Dictionary<CommandBytes, (RedisValue, HashSet<RedisClient>)>();
protected ArraySegment<RedisClient> FilterSubscribers(RedisChannel channel)
{
lock (_fullSubs)
lock (_longSubs)
{
if (!_fullSubs.TryGetValue(channel, out var clients)) return default;
if (!_longSubs.TryGetValue(channel, out var clients)) return default;
var arr = ArrayPool<RedisClient>.Shared.Rent(clients.Count);
clients.CopyTo(arr);
return new ArraySegment<RedisClient>(arr, 0, clients.Count);
}
}
private protected ArraySegment<RedisClient> FilterSubscribers(CommandBytes channel, out RedisValue value)
{
lock (_shortSubs)
{
if (!_shortSubs.TryGetValue(channel, out var tmp))
{
value = default;
return default;
}
value = tmp.Channel;
var clients = tmp.Clients;
var arr = ArrayPool<RedisClient>.Shared.Rent(clients.Count);
clients.CopyTo(arr);
return new ArraySegment<RedisClient>(arr, 0, clients.Count);
}
}
private protected ArraySegment<RedisClient> FilterSubscribers(CommandBytes channel)
{
lock (_shortSubs)
{
if (!_shortSubs.TryGetValue(channel, out var tmp)) return default;
var clients = tmp.Clients;
var arr = ArrayPool<RedisClient>.Shared.Rent(clients.Count);
clients.CopyTo(arr);
return new ArraySegment<RedisClient>(arr, 0, clients.Count);
}
}
protected override void OnRemoveClient(RedisClient client)
{
lock (_fullSubs)
lock (_longSubs)
{
List<RedisChannel> nowEmpty = null;
foreach(var pair in _fullSubs)
foreach(var pair in _longSubs)
{
var set = pair.Value;
if(set.Remove(client) && set.Count == 0)
......@@ -626,7 +708,23 @@ protected override void OnRemoveClient(RedisClient client)
}
if(nowEmpty != null)
{
foreach (var channel in nowEmpty) _fullSubs.Remove(channel);
foreach (var channel in nowEmpty) _longSubs.Remove(channel);
}
}
lock(_shortSubs)
{
List<CommandBytes> nowEmpty = null;
foreach (var pair in _shortSubs)
{
var set = pair.Value.Clients;
if (set.Remove(client) && set.Count == 0)
{
(nowEmpty ?? (nowEmpty = new List<CommandBytes>())).Add(pair.Key);
}
}
if (nowEmpty != null)
{
foreach (var channel in nowEmpty) _shortSubs.Remove(channel);
}
}
base.OnRemoveClient(client);
......@@ -637,14 +735,38 @@ protected virtual TypedRedisValue PubsubNumsub(RedisClient client, RedisRequest
{
var reply = TypedRedisValue.Rent((request.Count - 2) * 2, out var span);
int index = 0;
lock (_fullSubs)
lock (_longSubs)
{
for(int i = 2; i < request.Count; i++)
lock (_shortSubs)
{
var channel = request.GetChannel(i, RedisChannel.PatternMode.Literal);
var count = _fullSubs.TryGetValue(channel, out var clients) ? clients.Count : 0;
span[index++] = TypedRedisValue.BulkString(channel.Value);
span[index++] = TypedRedisValue.Integer(count);
for (int i = 2; i < request.Count; i++)
{
int count;
TypedRedisValue channel;
if (request.TryGetCommandBytes(i, out var shortChannel, caseInsensitive: false))
{
if (_shortSubs.TryGetValue(shortChannel, out var tmp))
{
count = tmp.Clients.Count;
channel = TypedRedisValue.BulkString(tmp.Channel);
}
else
{
count = 0;
channel = TypedRedisValue.BulkString(request.GetValue(i));
}
}
else
{
var longChannel = request.GetChannel(i, RedisChannel.PatternMode.Literal);
count = _longSubs.TryGetValue(longChannel, out var clients) ? clients.Count : 0;
channel = TypedRedisValue.BulkString(longChannel.Value);
}
span[index++] = channel;
span[index++] = TypedRedisValue.Integer(count);
}
}
}
return reply;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment