Commit f82093d8 authored by Marc Gravell's avatar Marc Gravell

implement pub/sub

parent 68af2f79
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.IO.Pipelines; using System.IO.Pipelines;
using System.Threading;
using System.Threading.Tasks;
namespace StackExchange.Redis.Server namespace StackExchange.Redis.Server
{ {
...@@ -20,16 +22,39 @@ internal bool ShouldSkipResponse() ...@@ -20,16 +22,39 @@ internal bool ShouldSkipResponse()
public int SubscriptionCount => _subscripions?.Count ?? 0; public int SubscriptionCount => _subscripions?.Count ?? 0;
internal int Subscribe(RedisChannel channel) internal int Subscribe(RedisChannel channel)
{ {
if (_subscripions == null) _subscripions = new HashSet<RedisChannel>(); var subs = _subscripions;
_subscripions.Add(channel); if(subs == null)
return _subscripions.Count; {
subs = new HashSet<RedisChannel>(); // but need to watch for compete
subs = Interlocked.CompareExchange(ref _subscripions, subs, null) ?? subs;
}
lock (subs)
{
subs.Add(channel);
return subs.Count;
}
} }
internal int Unsubscribe(RedisChannel channel) internal int Unsubscribe(RedisChannel channel)
{ {
if (_subscripions == null) return 0; var subs = _subscripions;
_subscripions.Remove(channel); if (subs == null) return 0;
return _subscripions.Count; lock (subs)
{
subs.Remove(channel);
return subs.Count;
}
} }
internal bool IsSubscribed(RedisChannel channel)
{
var subs = _subscripions;
if (subs == null) return false;
lock (subs)
{
return subs.Contains(channel);
}
}
public int Database { get; set; } public int Database { get; set; }
public string Name { get; set; } public string Name { get; set; }
internal IDuplexPipe LinkedPipe { get; set; } internal IDuplexPipe LinkedPipe { get; set; }
...@@ -50,5 +75,9 @@ public void Dispose() ...@@ -50,5 +75,9 @@ public void Dispose()
if (pipe is IDisposable d) try { d.Dispose(); } catch { } if (pipe is IDisposable d) try { d.Dispose(); } catch { }
} }
} }
private readonly SemaphoreSlim _writeLock = new SemaphoreSlim(1);
internal Task TakeWriteLockAsync() => _writeLock.WaitAsync();
internal void ReleasseWriteLock() => _writeLock.Release();
} }
} }
using System; using System;
using System.Buffers;
using System.Collections.Generic; using System.Collections.Generic;
using System.Diagnostics; using System.Diagnostics;
using System.IO; using System.IO;
using System.Text; using System.Text;
using System.Threading.Tasks;
namespace StackExchange.Redis.Server namespace StackExchange.Redis.Server
{ {
...@@ -435,7 +437,20 @@ protected virtual TypedRedisValue Mset(RedisClient client, RedisRequest request) ...@@ -435,7 +437,20 @@ protected virtual TypedRedisValue Mset(RedisClient client, RedisRequest request)
} }
[RedisCommand(-1, LockFree = true, MaxArgs = 2)] [RedisCommand(-1, LockFree = true, MaxArgs = 2)]
protected virtual TypedRedisValue Ping(RedisClient client, RedisRequest request) protected virtual TypedRedisValue Ping(RedisClient client, RedisRequest request)
=> TypedRedisValue.SimpleString(request.Count == 1 ? "PONG" : request.GetString(1)); {
if (client.SubscriptionCount == 0)
{
return TypedRedisValue.SimpleString(request.Count == 1 ? "PONG" : request.GetString(1));
}
else
{
// strictly speaking this is a >=3.0 feature
var pong = TypedRedisValue.Rent(2, out var span);
span[0] = TypedRedisValue.BulkString("pong");
span[1] = TypedRedisValue.BulkString(request.Count == 1 ? RedisValue.EmptyString : request.GetValue(1));
return pong;
}
}
[RedisCommand(1, LockFree = true)] [RedisCommand(1, LockFree = true)]
protected virtual TypedRedisValue Quit(RedisClient client, RedisRequest request) protected virtual TypedRedisValue Quit(RedisClient client, RedisRequest request)
...@@ -465,10 +480,11 @@ protected virtual TypedRedisValue Select(RedisClient client, RedisRequest reques ...@@ -465,10 +480,11 @@ protected virtual TypedRedisValue Select(RedisClient client, RedisRequest reques
return TypedRedisValue.OK; return TypedRedisValue.OK;
} }
[RedisCommand(-2)] [RedisCommand(-2, LockFree = true)]
protected virtual TypedRedisValue Subscribe(RedisClient client, RedisRequest request) protected virtual TypedRedisValue Subscribe(RedisClient client, RedisRequest request)
=> SubscribeImpl(client, request); => SubscribeImpl(client, request);
[RedisCommand(-2)]
[RedisCommand(-2, LockFree = true)]
protected virtual TypedRedisValue Unsubscribe(RedisClient client, RedisRequest request) protected virtual TypedRedisValue Unsubscribe(RedisClient client, RedisRequest request)
=> SubscribeImpl(client, request); => SubscribeImpl(client, request);
...@@ -507,6 +523,66 @@ private static readonly CommandBytes ...@@ -507,6 +523,66 @@ private static readonly CommandBytes
s_Unsubscribe = new CommandBytes("unsubscribe"); s_Unsubscribe = new CommandBytes("unsubscribe");
private static readonly DateTime UnixEpoch = new DateTime(1970, 1, 1, 0, 0, 0, DateTimeKind.Utc); private static readonly DateTime UnixEpoch = new DateTime(1970, 1, 1, 0, 0, 0, DateTimeKind.Utc);
static readonly RedisValue s_MESSAGE = Encoding.ASCII.GetBytes("message");
static TypedRedisValue CreateBroadcastMessage(RedisChannel channel, RedisValue payload)
{
var msg = TypedRedisValue.Rent(3, out var span);
span[0] = TypedRedisValue.BulkString(s_MESSAGE);
span[1] = TypedRedisValue.BulkString(channel.Value);
span[2] = TypedRedisValue.BulkString(payload);
return msg;
}
private async ValueTask<bool> TrySendOutOfBandAsync(RedisClient client, TypedRedisValue value)
{
try
{
var output = client?.LinkedPipe?.Output;
if (output == null)
{
Console.WriteLine("No pipe");
return false;
}
await WriteResponseAsync(client, output, value);
return true;
}
catch (Exception ex)
{
Console.WriteLine(ex.Message);
return false;
}
}
private async Task BackgroundPublish(ArraySegment<RedisClient> clients, RedisChannel channel, RedisValue payload)
{
var msg = CreateBroadcastMessage(channel, payload);
foreach (var sub in clients)
{
await TrySendOutOfBandAsync(sub, msg);
}
// only recycle on success, to avoid issues
msg.Recycle();
ArrayPool<RedisClient>.Shared.Return(clients.Array);
}
[RedisCommand(3, LockFree = true)]
protected virtual TypedRedisValue Publish(RedisClient client, RedisRequest request)
{
var channel = request.GetChannel(1, RedisChannel.PatternMode.Literal);
var subscribers = FilterSubscribers(channel);
int count = subscribers.Count;
if (count != 0)
{
var payload = request.GetValue(2);
Task.Run(() => BackgroundPublish(subscribers, channel, payload));
}
return TypedRedisValue.Integer(count);
}
[RedisCommand(1, LockFree = true)] [RedisCommand(1, LockFree = true)]
protected virtual TypedRedisValue Time(RedisClient client, RedisRequest request) protected virtual TypedRedisValue Time(RedisClient client, RedisRequest request)
{ {
......
...@@ -189,17 +189,41 @@ internal int NetArity() ...@@ -189,17 +189,41 @@ internal int NetArity()
// for extensibility, so that a subclass can get their own client type // for extensibility, so that a subclass can get their own client type
// to be used via ListenForConnections // to be used via ListenForConnections
public virtual RedisClient CreateClient() => new RedisClient(); public virtual RedisClient CreateClient(IDuplexPipe pipe) => new RedisClient { LinkedPipe = pipe };
public int ClientCount public int ClientCount
{ {
get { lock (_clients) { return _clients.Count; } } get { lock (_clients) { return _clients.Count; } }
} }
protected ArraySegment<RedisClient> FilterSubscribers(RedisChannel channel)
{
lock (_clients)
{
var arr = ArrayPool<RedisClient>.Shared.Rent(_clients.Count);
int count = 0;
foreach(var client in _clients)
{
if (client.IsSubscribed(channel))
arr[count++] = client;
}
if (count == 0)
{
ArrayPool<RedisClient>.Shared.Return(arr);
return default; // Count=0, importantly
}
else
{
return new ArraySegment<RedisClient>(arr, 0, count);
}
}
}
public int TotalClientCount { get; private set; } public int TotalClientCount { get; private set; }
private int _nextId; private int _nextId;
public RedisClient AddClient() public RedisClient AddClient(IDuplexPipe pipe)
{ {
var client = CreateClient(); var client = CreateClient(pipe);
lock (_clients) lock (_clients)
{ {
ThrowIfShutdown(); ThrowIfShutdown();
...@@ -250,7 +274,7 @@ public async Task RunClientAsync(IDuplexPipe pipe) ...@@ -250,7 +274,7 @@ public async Task RunClientAsync(IDuplexPipe pipe)
RedisClient client = null; RedisClient client = null;
try try
{ {
client = AddClient(); client = AddClient(pipe);
while (!client.Closed) while (!client.Closed)
{ {
var readResult = await pipe.Input.ReadAsync().ConfigureAwait(false); var readResult = await pipe.Input.ReadAsync().ConfigureAwait(false);
...@@ -314,54 +338,69 @@ void WritePrefix(PipeWriter ooutput, char pprefix) ...@@ -314,54 +338,69 @@ void WritePrefix(PipeWriter ooutput, char pprefix)
if (value.IsNil) return; // not actually a request (i.e. empty/whitespace request) if (value.IsNil) return; // not actually a request (i.e. empty/whitespace request)
if (client != null && client.ShouldSkipResponse()) return; // intentionally skipping the result if (client != null && client.ShouldSkipResponse()) return; // intentionally skipping the result
char prefix;
switch (value.Type) if (client != null)
{ {
case ResultType.Integer: await client.TakeWriteLockAsync();
PhysicalConnection.WriteInteger(output, (long)value.AsRedisValue()); }
break; try
case ResultType.Error: {
prefix = '-'; char prefix;
goto BasicMessage; switch (value.Type)
case ResultType.SimpleString: {
prefix = '+'; case ResultType.Integer:
BasicMessage: PhysicalConnection.WriteInteger(output, (long)value.AsRedisValue());
WritePrefix(output, prefix); break;
var val = (string)value.AsRedisValue(); case ResultType.Error:
var expectedLength = Encoding.UTF8.GetByteCount(val); prefix = '-';
PhysicalConnection.WriteRaw(output, val, expectedLength); goto BasicMessage;
PhysicalConnection.WriteCrlf(output); case ResultType.SimpleString:
break; prefix = '+';
case ResultType.BulkString: BasicMessage:
PhysicalConnection.WriteBulkString(value.AsRedisValue(), output); WritePrefix(output, prefix);
break; var val = (string)value.AsRedisValue();
case ResultType.MultiBulk: var expectedLength = Encoding.UTF8.GetByteCount(val);
if (value.IsNullArray) PhysicalConnection.WriteRaw(output, val, expectedLength);
{ PhysicalConnection.WriteCrlf(output);
PhysicalConnection.WriteMultiBulkHeader(output, -1); break;
} case ResultType.BulkString:
else PhysicalConnection.WriteBulkString(value.AsRedisValue(), output);
{ break;
var segment = value.Segment; case ResultType.MultiBulk:
PhysicalConnection.WriteMultiBulkHeader(output, segment.Count); if (value.IsNullArray)
var arr = segment.Array;
int offset = segment.Offset;
for (int i = 0; i < segment.Count; i++)
{ {
var item = arr[offset++]; PhysicalConnection.WriteMultiBulkHeader(output, -1);
if (item.IsNil)
throw new InvalidOperationException("Array element cannot be nil, index " + i);
// note: don't pass client down; this would impact SkipReplies
await WriteResponseAsync(null, output, item);
} }
} else
break; {
default: var segment = value.Segment;
throw new InvalidOperationException( PhysicalConnection.WriteMultiBulkHeader(output, segment.Count);
"Unexpected result type: " + value.Type); var arr = segment.Array;
int offset = segment.Offset;
for (int i = 0; i < segment.Count; i++)
{
var item = arr[offset++];
if (item.IsNil)
throw new InvalidOperationException("Array element cannot be nil, index " + i);
// note: don't pass client down; this would impact SkipReplies
await WriteResponseAsync(null, output, item);
}
}
break;
default:
throw new InvalidOperationException(
"Unexpected result type: " + value.Type);
}
await output.FlushAsync().ConfigureAwait(false);
}
finally
{
if (client != null)
{
client.ReleasseWriteLock();
}
} }
await output.FlushAsync().ConfigureAwait(false);
} }
public static bool TryParseRequest(ref ReadOnlySequence<byte> buffer, out RedisRequest request) public static bool TryParseRequest(ref ReadOnlySequence<byte> buffer, out RedisRequest request)
{ {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment