Commit 93ab472f authored by Marc Gravell's avatar Marc Gravell

optimize pub/sub for short (<=23 bytes) channel keys

parent c1fd3b0b
...@@ -25,7 +25,7 @@ internal unsafe static CommandBytes TrimToFit(string value) ...@@ -25,7 +25,7 @@ internal unsafe static CommandBytes TrimToFit(string value)
} }
} }
// Uses [n=4] x UInt64 values to store a command payload, // Uses [ChunkLength] x UInt64 values to store a command payload,
// allowing allocation free storage and efficient // allowing allocation free storage and efficient
// equality tests. If you're glancing at this and thinking // equality tests. If you're glancing at this and thinking
// "that's what fixed buffers are for", please see: // "that's what fixed buffers are for", please see:
...@@ -114,7 +114,7 @@ public unsafe CommandBytes(string value) ...@@ -114,7 +114,7 @@ public unsafe CommandBytes(string value)
} }
} }
public unsafe CommandBytes(ReadOnlySpan<byte> value) public unsafe CommandBytes(ReadOnlySpan<byte> value, bool caseInsensitive = true)
{ {
if (value.Length > MaxLength) throw new ArgumentOutOfRangeException("Maximum command length exceeed: " + value.Length + " bytes"); if (value.Length > MaxLength) throw new ArgumentOutOfRangeException("Maximum command length exceeed: " + value.Length + " bytes");
_0 = _1 = _2 = 0L; _0 = _1 = _2 = 0L;
...@@ -122,10 +122,17 @@ public unsafe CommandBytes(ReadOnlySpan<byte> value) ...@@ -122,10 +122,17 @@ public unsafe CommandBytes(ReadOnlySpan<byte> value)
{ {
byte* bPtr = (byte*)uPtr; byte* bPtr = (byte*)uPtr;
value.CopyTo(new Span<byte>(bPtr + 1, value.Length)); value.CopyTo(new Span<byte>(bPtr + 1, value.Length));
if (caseInsensitive)
{
*bPtr = (byte)UpperCasify(value.Length, bPtr + 1); *bPtr = (byte)UpperCasify(value.Length, bPtr + 1);
} }
else
{
*bPtr = (byte)value.Length;
}
}
} }
public unsafe CommandBytes(ReadOnlySequence<byte> value) public unsafe CommandBytes(ReadOnlySequence<byte> value, bool caseInsensitive = true)
{ {
if (value.Length > MaxLength) throw new ArgumentOutOfRangeException("Maximum command length exceeed"); if (value.Length > MaxLength) throw new ArgumentOutOfRangeException("Maximum command length exceeed");
int len = unchecked((int)value.Length); int len = unchecked((int)value.Length);
...@@ -147,8 +154,15 @@ public unsafe CommandBytes(ReadOnlySequence<byte> value) ...@@ -147,8 +154,15 @@ public unsafe CommandBytes(ReadOnlySequence<byte> value)
target = target.Slice(segment.Length); target = target.Slice(segment.Length);
} }
} }
if (caseInsensitive)
{
*bPtr = (byte)UpperCasify(len, bPtr + 1); *bPtr = (byte)UpperCasify(len, bPtr + 1);
} }
else
{
*bPtr = (byte)len;
}
}
} }
private unsafe int UpperCasify(int len, byte* bPtr) private unsafe int UpperCasify(int len, byte* bPtr)
{ {
......
...@@ -49,16 +49,16 @@ public int GetInt32(int index) ...@@ -49,16 +49,16 @@ public int GetInt32(int index)
public RedisChannel GetChannel(int index, RedisChannel.PatternMode mode) public RedisChannel GetChannel(int index, RedisChannel.PatternMode mode)
=> _inner[index].AsRedisChannel(null, mode); => _inner[index].AsRedisChannel(null, mode);
internal bool TryGetCommandBytes(int i, out CommandBytes command) internal bool TryGetCommandBytes(int index, out CommandBytes command, bool caseInsensitive = true)
{ {
var payload = _inner[i].Payload; var payload = _inner[index].Payload;
if (payload.Length > CommandBytes.MaxLength) if (payload.Length > CommandBytes.MaxLength)
{ {
command = default; command = default;
return false; return false;
} }
command = payload.IsEmpty ? default : new CommandBytes(payload); command = payload.IsEmpty ? default : new CommandBytes(payload, caseInsensitive);
return true; return true;
} }
} }
......
...@@ -497,24 +497,66 @@ private TypedRedisValue SubscribeImpl(RedisClient client, RedisRequest request) ...@@ -497,24 +497,66 @@ private TypedRedisValue SubscribeImpl(RedisClient client, RedisRequest request)
var mode = cmd[0] == (byte)'p' ? RedisChannel.PatternMode.Pattern : RedisChannel.PatternMode.Literal; var mode = cmd[0] == (byte)'p' ? RedisChannel.PatternMode.Pattern : RedisChannel.PatternMode.Literal;
for (int i = 1; i < request.Count; i++) for (int i = 1; i < request.Count; i++)
{ {
var channel = request.GetChannel(i, mode);
int count; int count;
TypedRedisValue channel;
if (request.TryGetCommandBytes(i, out var shortChannel, caseInsensitive: false))
{
lock (_shortSubs)
{
count = client.SubscriptionCount;
if (s_Subscribe.Equals(cmd))
{
if (!_shortSubs.TryGetValue(shortChannel, out var tmp))
{
tmp = (request.GetValue(i), new HashSet<RedisClient>());
_shortSubs.Add(shortChannel, tmp);
}
channel = TypedRedisValue.BulkString(tmp.Channel);
if (tmp.Clients.Add(client)) count = client.IncrSubscsriptionCount();
}
else if (s_Unsubscribe.Equals(cmd))
{
if (_shortSubs.TryGetValue(shortChannel, out var tmp))
{
channel = TypedRedisValue.BulkString(tmp.Channel);
if (tmp.Clients.Remove(client))
{
count = client.DecrSubscsriptionCount();
}
}
else
{
channel = TypedRedisValue.BulkString(request.GetValue(i));
}
}
else
{
reply.Recycle(index);
return TypedRedisValue.Nil;
}
// channel = TypedRedisValue.BulkString((byte[])longChannel);
}
}
else
{
var longChannel = request.GetChannel(i, mode);
lock (_fullSubs) lock (_longSubs)
{ {
count = client.SubscriptionCount; count = client.SubscriptionCount;
if (s_Subscribe.Equals(cmd)) if (s_Subscribe.Equals(cmd))
{ {
if(!_fullSubs.TryGetValue(channel, out var clients)) if (!_longSubs.TryGetValue(longChannel, out var clients))
{ {
clients = new HashSet<RedisClient>(); clients = new HashSet<RedisClient>();
_fullSubs.Add(channel, clients); _longSubs.Add(longChannel, clients);
} }
if (clients.Add(client)) count = client.IncrSubscsriptionCount(); if (clients.Add(client)) count = client.IncrSubscsriptionCount();
} }
else if (s_Unsubscribe.Equals(cmd)) else if (s_Unsubscribe.Equals(cmd))
{ {
if (_fullSubs.TryGetValue(channel, out var clients) if (_longSubs.TryGetValue(longChannel, out var clients)
&& clients.Remove(client)) && clients.Remove(client))
{ {
count = client.DecrSubscsriptionCount(); count = client.DecrSubscsriptionCount();
...@@ -525,9 +567,12 @@ private TypedRedisValue SubscribeImpl(RedisClient client, RedisRequest request) ...@@ -525,9 +567,12 @@ private TypedRedisValue SubscribeImpl(RedisClient client, RedisRequest request)
reply.Recycle(index); reply.Recycle(index);
return TypedRedisValue.Nil; return TypedRedisValue.Nil;
} }
channel = TypedRedisValue.BulkString((byte[])longChannel);
}
} }
span[index++] = cmdString; span[index++] = cmdString;
span[index++] = TypedRedisValue.BulkString((byte[])channel); span[index++] = channel;
span[index++] = TypedRedisValue.Integer(count); span[index++] = TypedRedisValue.Integer(count);
} }
return reply; return reply;
...@@ -538,11 +583,11 @@ private static readonly CommandBytes ...@@ -538,11 +583,11 @@ private static readonly CommandBytes
private static readonly DateTime UnixEpoch = new DateTime(1970, 1, 1, 0, 0, 0, DateTimeKind.Utc); private static readonly DateTime UnixEpoch = new DateTime(1970, 1, 1, 0, 0, 0, DateTimeKind.Utc);
private static readonly RedisValue s_MESSAGE = Encoding.ASCII.GetBytes("message"); private static readonly RedisValue s_MESSAGE = Encoding.ASCII.GetBytes("message");
private static TypedRedisValue CreateBroadcastMessage(RedisChannel channel, RedisValue payload) private static TypedRedisValue CreateBroadcastMessage(RedisValue channel, RedisValue payload)
{ {
var msg = TypedRedisValue.Rent(3, out var span); var msg = TypedRedisValue.Rent(3, out var span);
span[0] = TypedRedisValue.BulkString(s_MESSAGE); span[0] = TypedRedisValue.BulkString(s_MESSAGE);
span[1] = TypedRedisValue.BulkString(channel.Value); span[1] = TypedRedisValue.BulkString(channel);
span[2] = TypedRedisValue.BulkString(payload); span[2] = TypedRedisValue.BulkString(payload);
return msg; return msg;
} }
...@@ -563,7 +608,7 @@ private async ValueTask<bool> TrySendOutOfBandAsync(RedisClient client, TypedRed ...@@ -563,7 +608,7 @@ private async ValueTask<bool> TrySendOutOfBandAsync(RedisClient client, TypedRed
} }
} }
private async Task BackgroundPublish(ArraySegment<RedisClient> clients, RedisChannel channel, RedisValue payload) private async Task BackgroundPublish(ArraySegment<RedisClient> clients, RedisValue channel, RedisValue payload)
{ {
try try
{ {
...@@ -585,10 +630,19 @@ private async Task BackgroundPublish(ArraySegment<RedisClient> clients, RedisCha ...@@ -585,10 +630,19 @@ private async Task BackgroundPublish(ArraySegment<RedisClient> clients, RedisCha
[RedisCommand(3, LockFree = true)] [RedisCommand(3, LockFree = true)]
protected virtual TypedRedisValue Publish(RedisClient client, RedisRequest request) protected virtual TypedRedisValue Publish(RedisClient client, RedisRequest request)
{ {
var channel = request.GetChannel(1, RedisChannel.PatternMode.Literal); ArraySegment<RedisClient> subscribers;
RedisValue channel;
var subscribers = FilterSubscribers(channel); if (request.TryGetCommandBytes(1, out var shortChannel, caseInsensitive: false))
int count = subscribers.Count; {
subscribers = FilterSubscribers(shortChannel, out channel);
}
else
{
var longChannel = request.GetChannel(1, RedisChannel.PatternMode.Literal);
subscribers = FilterSubscribers(longChannel);
channel = longChannel.Value;
}
var count = subscribers.Count;
if (count != 0) if (count != 0)
{ {
var payload = request.GetValue(2); var payload = request.GetValue(2);
...@@ -597,26 +651,54 @@ protected virtual TypedRedisValue Publish(RedisClient client, RedisRequest reque ...@@ -597,26 +651,54 @@ protected virtual TypedRedisValue Publish(RedisClient client, RedisRequest reque
return TypedRedisValue.Integer(count); return TypedRedisValue.Integer(count);
} }
private readonly Dictionary<RedisChannel, HashSet<RedisClient>> _fullSubs = new Dictionary<RedisChannel, HashSet<RedisClient>>(); private readonly Dictionary<RedisChannel, HashSet<RedisClient>> _longSubs = new Dictionary<RedisChannel, HashSet<RedisClient>>();
private readonly Dictionary<CommandBytes, (RedisValue Channel,HashSet<RedisClient> Clients)> _shortSubs = new Dictionary<CommandBytes, (RedisValue, HashSet<RedisClient>)>();
protected ArraySegment<RedisClient> FilterSubscribers(RedisChannel channel) protected ArraySegment<RedisClient> FilterSubscribers(RedisChannel channel)
{ {
lock (_fullSubs) lock (_longSubs)
{ {
if (!_fullSubs.TryGetValue(channel, out var clients)) return default; if (!_longSubs.TryGetValue(channel, out var clients)) return default;
var arr = ArrayPool<RedisClient>.Shared.Rent(clients.Count); var arr = ArrayPool<RedisClient>.Shared.Rent(clients.Count);
clients.CopyTo(arr); clients.CopyTo(arr);
return new ArraySegment<RedisClient>(arr, 0, clients.Count); return new ArraySegment<RedisClient>(arr, 0, clients.Count);
} }
} }
private protected ArraySegment<RedisClient> FilterSubscribers(CommandBytes channel, out RedisValue value)
{
lock (_shortSubs)
{
if (!_shortSubs.TryGetValue(channel, out var tmp))
{
value = default;
return default;
}
value = tmp.Channel;
var clients = tmp.Clients;
var arr = ArrayPool<RedisClient>.Shared.Rent(clients.Count);
clients.CopyTo(arr);
return new ArraySegment<RedisClient>(arr, 0, clients.Count);
}
}
private protected ArraySegment<RedisClient> FilterSubscribers(CommandBytes channel)
{
lock (_shortSubs)
{
if (!_shortSubs.TryGetValue(channel, out var tmp)) return default;
var clients = tmp.Clients;
var arr = ArrayPool<RedisClient>.Shared.Rent(clients.Count);
clients.CopyTo(arr);
return new ArraySegment<RedisClient>(arr, 0, clients.Count);
}
}
protected override void OnRemoveClient(RedisClient client) protected override void OnRemoveClient(RedisClient client)
{ {
lock (_fullSubs) lock (_longSubs)
{ {
List<RedisChannel> nowEmpty = null; List<RedisChannel> nowEmpty = null;
foreach(var pair in _fullSubs) foreach(var pair in _longSubs)
{ {
var set = pair.Value; var set = pair.Value;
if(set.Remove(client) && set.Count == 0) if(set.Remove(client) && set.Count == 0)
...@@ -626,7 +708,23 @@ protected override void OnRemoveClient(RedisClient client) ...@@ -626,7 +708,23 @@ protected override void OnRemoveClient(RedisClient client)
} }
if(nowEmpty != null) if(nowEmpty != null)
{ {
foreach (var channel in nowEmpty) _fullSubs.Remove(channel); foreach (var channel in nowEmpty) _longSubs.Remove(channel);
}
}
lock(_shortSubs)
{
List<CommandBytes> nowEmpty = null;
foreach (var pair in _shortSubs)
{
var set = pair.Value.Clients;
if (set.Remove(client) && set.Count == 0)
{
(nowEmpty ?? (nowEmpty = new List<CommandBytes>())).Add(pair.Key);
}
}
if (nowEmpty != null)
{
foreach (var channel in nowEmpty) _shortSubs.Remove(channel);
} }
} }
base.OnRemoveClient(client); base.OnRemoveClient(client);
...@@ -637,16 +735,40 @@ protected virtual TypedRedisValue PubsubNumsub(RedisClient client, RedisRequest ...@@ -637,16 +735,40 @@ protected virtual TypedRedisValue PubsubNumsub(RedisClient client, RedisRequest
{ {
var reply = TypedRedisValue.Rent((request.Count - 2) * 2, out var span); var reply = TypedRedisValue.Rent((request.Count - 2) * 2, out var span);
int index = 0; int index = 0;
lock (_fullSubs) lock (_longSubs)
{
lock (_shortSubs)
{
for (int i = 2; i < request.Count; i++)
{ {
for(int i = 2; i < request.Count; i++) int count;
TypedRedisValue channel;
if (request.TryGetCommandBytes(i, out var shortChannel, caseInsensitive: false))
{ {
var channel = request.GetChannel(i, RedisChannel.PatternMode.Literal); if (_shortSubs.TryGetValue(shortChannel, out var tmp))
var count = _fullSubs.TryGetValue(channel, out var clients) ? clients.Count : 0; {
span[index++] = TypedRedisValue.BulkString(channel.Value); count = tmp.Clients.Count;
channel = TypedRedisValue.BulkString(tmp.Channel);
}
else
{
count = 0;
channel = TypedRedisValue.BulkString(request.GetValue(i));
}
}
else
{
var longChannel = request.GetChannel(i, RedisChannel.PatternMode.Literal);
count = _longSubs.TryGetValue(longChannel, out var clients) ? clients.Count : 0;
channel = TypedRedisValue.BulkString(longChannel.Value);
}
span[index++] = channel;
span[index++] = TypedRedisValue.Integer(count); span[index++] = TypedRedisValue.Integer(count);
} }
} }
}
return reply; return reply;
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment