Commit f82093d8 authored by Marc Gravell's avatar Marc Gravell

implement pub/sub

parent 68af2f79
using System;
using System.Collections.Generic;
using System.IO.Pipelines;
using System.Threading;
using System.Threading.Tasks;
namespace StackExchange.Redis.Server
{
......@@ -20,16 +22,39 @@ internal bool ShouldSkipResponse()
public int SubscriptionCount => _subscripions?.Count ?? 0;
internal int Subscribe(RedisChannel channel)
{
if (_subscripions == null) _subscripions = new HashSet<RedisChannel>();
_subscripions.Add(channel);
return _subscripions.Count;
var subs = _subscripions;
if(subs == null)
{
subs = new HashSet<RedisChannel>(); // but need to watch for compete
subs = Interlocked.CompareExchange(ref _subscripions, subs, null) ?? subs;
}
lock (subs)
{
subs.Add(channel);
return subs.Count;
}
}
internal int Unsubscribe(RedisChannel channel)
{
if (_subscripions == null) return 0;
_subscripions.Remove(channel);
return _subscripions.Count;
var subs = _subscripions;
if (subs == null) return 0;
lock (subs)
{
subs.Remove(channel);
return subs.Count;
}
}
internal bool IsSubscribed(RedisChannel channel)
{
var subs = _subscripions;
if (subs == null) return false;
lock (subs)
{
return subs.Contains(channel);
}
}
public int Database { get; set; }
public string Name { get; set; }
internal IDuplexPipe LinkedPipe { get; set; }
......@@ -50,5 +75,9 @@ public void Dispose()
if (pipe is IDisposable d) try { d.Dispose(); } catch { }
}
}
private readonly SemaphoreSlim _writeLock = new SemaphoreSlim(1);
internal Task TakeWriteLockAsync() => _writeLock.WaitAsync();
internal void ReleasseWriteLock() => _writeLock.Release();
}
}
using System;
using System.Buffers;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using System.Text;
using System.Threading.Tasks;
namespace StackExchange.Redis.Server
{
......@@ -435,7 +437,20 @@ protected virtual TypedRedisValue Mset(RedisClient client, RedisRequest request)
}
[RedisCommand(-1, LockFree = true, MaxArgs = 2)]
protected virtual TypedRedisValue Ping(RedisClient client, RedisRequest request)
=> TypedRedisValue.SimpleString(request.Count == 1 ? "PONG" : request.GetString(1));
{
if (client.SubscriptionCount == 0)
{
return TypedRedisValue.SimpleString(request.Count == 1 ? "PONG" : request.GetString(1));
}
else
{
// strictly speaking this is a >=3.0 feature
var pong = TypedRedisValue.Rent(2, out var span);
span[0] = TypedRedisValue.BulkString("pong");
span[1] = TypedRedisValue.BulkString(request.Count == 1 ? RedisValue.EmptyString : request.GetValue(1));
return pong;
}
}
[RedisCommand(1, LockFree = true)]
protected virtual TypedRedisValue Quit(RedisClient client, RedisRequest request)
......@@ -465,10 +480,11 @@ protected virtual TypedRedisValue Select(RedisClient client, RedisRequest reques
return TypedRedisValue.OK;
}
[RedisCommand(-2)]
[RedisCommand(-2, LockFree = true)]
protected virtual TypedRedisValue Subscribe(RedisClient client, RedisRequest request)
=> SubscribeImpl(client, request);
[RedisCommand(-2)]
[RedisCommand(-2, LockFree = true)]
protected virtual TypedRedisValue Unsubscribe(RedisClient client, RedisRequest request)
=> SubscribeImpl(client, request);
......@@ -507,6 +523,66 @@ private static readonly CommandBytes
s_Unsubscribe = new CommandBytes("unsubscribe");
private static readonly DateTime UnixEpoch = new DateTime(1970, 1, 1, 0, 0, 0, DateTimeKind.Utc);
static readonly RedisValue s_MESSAGE = Encoding.ASCII.GetBytes("message");
static TypedRedisValue CreateBroadcastMessage(RedisChannel channel, RedisValue payload)
{
var msg = TypedRedisValue.Rent(3, out var span);
span[0] = TypedRedisValue.BulkString(s_MESSAGE);
span[1] = TypedRedisValue.BulkString(channel.Value);
span[2] = TypedRedisValue.BulkString(payload);
return msg;
}
private async ValueTask<bool> TrySendOutOfBandAsync(RedisClient client, TypedRedisValue value)
{
try
{
var output = client?.LinkedPipe?.Output;
if (output == null)
{
Console.WriteLine("No pipe");
return false;
}
await WriteResponseAsync(client, output, value);
return true;
}
catch (Exception ex)
{
Console.WriteLine(ex.Message);
return false;
}
}
private async Task BackgroundPublish(ArraySegment<RedisClient> clients, RedisChannel channel, RedisValue payload)
{
var msg = CreateBroadcastMessage(channel, payload);
foreach (var sub in clients)
{
await TrySendOutOfBandAsync(sub, msg);
}
// only recycle on success, to avoid issues
msg.Recycle();
ArrayPool<RedisClient>.Shared.Return(clients.Array);
}
[RedisCommand(3, LockFree = true)]
protected virtual TypedRedisValue Publish(RedisClient client, RedisRequest request)
{
var channel = request.GetChannel(1, RedisChannel.PatternMode.Literal);
var subscribers = FilterSubscribers(channel);
int count = subscribers.Count;
if (count != 0)
{
var payload = request.GetValue(2);
Task.Run(() => BackgroundPublish(subscribers, channel, payload));
}
return TypedRedisValue.Integer(count);
}
[RedisCommand(1, LockFree = true)]
protected virtual TypedRedisValue Time(RedisClient client, RedisRequest request)
{
......
......@@ -189,17 +189,41 @@ internal int NetArity()
// for extensibility, so that a subclass can get their own client type
// to be used via ListenForConnections
public virtual RedisClient CreateClient() => new RedisClient();
public virtual RedisClient CreateClient(IDuplexPipe pipe) => new RedisClient { LinkedPipe = pipe };
public int ClientCount
{
get { lock (_clients) { return _clients.Count; } }
}
protected ArraySegment<RedisClient> FilterSubscribers(RedisChannel channel)
{
lock (_clients)
{
var arr = ArrayPool<RedisClient>.Shared.Rent(_clients.Count);
int count = 0;
foreach(var client in _clients)
{
if (client.IsSubscribed(channel))
arr[count++] = client;
}
if (count == 0)
{
ArrayPool<RedisClient>.Shared.Return(arr);
return default; // Count=0, importantly
}
else
{
return new ArraySegment<RedisClient>(arr, 0, count);
}
}
}
public int TotalClientCount { get; private set; }
private int _nextId;
public RedisClient AddClient()
public RedisClient AddClient(IDuplexPipe pipe)
{
var client = CreateClient();
var client = CreateClient(pipe);
lock (_clients)
{
ThrowIfShutdown();
......@@ -250,7 +274,7 @@ public async Task RunClientAsync(IDuplexPipe pipe)
RedisClient client = null;
try
{
client = AddClient();
client = AddClient(pipe);
while (!client.Closed)
{
var readResult = await pipe.Input.ReadAsync().ConfigureAwait(false);
......@@ -314,54 +338,69 @@ void WritePrefix(PipeWriter ooutput, char pprefix)
if (value.IsNil) return; // not actually a request (i.e. empty/whitespace request)
if (client != null && client.ShouldSkipResponse()) return; // intentionally skipping the result
char prefix;
switch (value.Type)
if (client != null)
{
case ResultType.Integer:
PhysicalConnection.WriteInteger(output, (long)value.AsRedisValue());
break;
case ResultType.Error:
prefix = '-';
goto BasicMessage;
case ResultType.SimpleString:
prefix = '+';
BasicMessage:
WritePrefix(output, prefix);
var val = (string)value.AsRedisValue();
var expectedLength = Encoding.UTF8.GetByteCount(val);
PhysicalConnection.WriteRaw(output, val, expectedLength);
PhysicalConnection.WriteCrlf(output);
break;
case ResultType.BulkString:
PhysicalConnection.WriteBulkString(value.AsRedisValue(), output);
break;
case ResultType.MultiBulk:
if (value.IsNullArray)
{
PhysicalConnection.WriteMultiBulkHeader(output, -1);
}
else
{
var segment = value.Segment;
PhysicalConnection.WriteMultiBulkHeader(output, segment.Count);
var arr = segment.Array;
int offset = segment.Offset;
for (int i = 0; i < segment.Count; i++)
await client.TakeWriteLockAsync();
}
try
{
char prefix;
switch (value.Type)
{
case ResultType.Integer:
PhysicalConnection.WriteInteger(output, (long)value.AsRedisValue());
break;
case ResultType.Error:
prefix = '-';
goto BasicMessage;
case ResultType.SimpleString:
prefix = '+';
BasicMessage:
WritePrefix(output, prefix);
var val = (string)value.AsRedisValue();
var expectedLength = Encoding.UTF8.GetByteCount(val);
PhysicalConnection.WriteRaw(output, val, expectedLength);
PhysicalConnection.WriteCrlf(output);
break;
case ResultType.BulkString:
PhysicalConnection.WriteBulkString(value.AsRedisValue(), output);
break;
case ResultType.MultiBulk:
if (value.IsNullArray)
{
var item = arr[offset++];
if (item.IsNil)
throw new InvalidOperationException("Array element cannot be nil, index " + i);
// note: don't pass client down; this would impact SkipReplies
await WriteResponseAsync(null, output, item);
PhysicalConnection.WriteMultiBulkHeader(output, -1);
}
}
break;
default:
throw new InvalidOperationException(
"Unexpected result type: " + value.Type);
else
{
var segment = value.Segment;
PhysicalConnection.WriteMultiBulkHeader(output, segment.Count);
var arr = segment.Array;
int offset = segment.Offset;
for (int i = 0; i < segment.Count; i++)
{
var item = arr[offset++];
if (item.IsNil)
throw new InvalidOperationException("Array element cannot be nil, index " + i);
// note: don't pass client down; this would impact SkipReplies
await WriteResponseAsync(null, output, item);
}
}
break;
default:
throw new InvalidOperationException(
"Unexpected result type: " + value.Type);
}
await output.FlushAsync().ConfigureAwait(false);
}
finally
{
if (client != null)
{
client.ReleasseWriteLock();
}
}
await output.FlushAsync().ConfigureAwait(false);
}
public static bool TryParseRequest(ref ReadOnlySequence<byte> buffer, out RedisRequest request)
{
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment