Commit f82093d8 authored by Marc Gravell's avatar Marc Gravell

implement pub/sub

parent 68af2f79
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.IO.Pipelines; using System.IO.Pipelines;
using System.Threading;
using System.Threading.Tasks;
namespace StackExchange.Redis.Server namespace StackExchange.Redis.Server
{ {
...@@ -20,16 +22,39 @@ internal bool ShouldSkipResponse() ...@@ -20,16 +22,39 @@ internal bool ShouldSkipResponse()
public int SubscriptionCount => _subscripions?.Count ?? 0; public int SubscriptionCount => _subscripions?.Count ?? 0;
internal int Subscribe(RedisChannel channel) internal int Subscribe(RedisChannel channel)
{ {
if (_subscripions == null) _subscripions = new HashSet<RedisChannel>(); var subs = _subscripions;
_subscripions.Add(channel); if(subs == null)
return _subscripions.Count; {
subs = new HashSet<RedisChannel>(); // but need to watch for compete
subs = Interlocked.CompareExchange(ref _subscripions, subs, null) ?? subs;
}
lock (subs)
{
subs.Add(channel);
return subs.Count;
}
} }
internal int Unsubscribe(RedisChannel channel) internal int Unsubscribe(RedisChannel channel)
{ {
if (_subscripions == null) return 0; var subs = _subscripions;
_subscripions.Remove(channel); if (subs == null) return 0;
return _subscripions.Count; lock (subs)
{
subs.Remove(channel);
return subs.Count;
}
} }
internal bool IsSubscribed(RedisChannel channel)
{
var subs = _subscripions;
if (subs == null) return false;
lock (subs)
{
return subs.Contains(channel);
}
}
public int Database { get; set; } public int Database { get; set; }
public string Name { get; set; } public string Name { get; set; }
internal IDuplexPipe LinkedPipe { get; set; } internal IDuplexPipe LinkedPipe { get; set; }
...@@ -50,5 +75,9 @@ public void Dispose() ...@@ -50,5 +75,9 @@ public void Dispose()
if (pipe is IDisposable d) try { d.Dispose(); } catch { } if (pipe is IDisposable d) try { d.Dispose(); } catch { }
} }
} }
private readonly SemaphoreSlim _writeLock = new SemaphoreSlim(1);
internal Task TakeWriteLockAsync() => _writeLock.WaitAsync();
internal void ReleasseWriteLock() => _writeLock.Release();
} }
} }
using System; using System;
using System.Buffers;
using System.Collections.Generic; using System.Collections.Generic;
using System.Diagnostics; using System.Diagnostics;
using System.IO; using System.IO;
using System.Text; using System.Text;
using System.Threading.Tasks;
namespace StackExchange.Redis.Server namespace StackExchange.Redis.Server
{ {
...@@ -435,7 +437,20 @@ protected virtual TypedRedisValue Mset(RedisClient client, RedisRequest request) ...@@ -435,7 +437,20 @@ protected virtual TypedRedisValue Mset(RedisClient client, RedisRequest request)
} }
[RedisCommand(-1, LockFree = true, MaxArgs = 2)] [RedisCommand(-1, LockFree = true, MaxArgs = 2)]
protected virtual TypedRedisValue Ping(RedisClient client, RedisRequest request) protected virtual TypedRedisValue Ping(RedisClient client, RedisRequest request)
=> TypedRedisValue.SimpleString(request.Count == 1 ? "PONG" : request.GetString(1)); {
if (client.SubscriptionCount == 0)
{
return TypedRedisValue.SimpleString(request.Count == 1 ? "PONG" : request.GetString(1));
}
else
{
// strictly speaking this is a >=3.0 feature
var pong = TypedRedisValue.Rent(2, out var span);
span[0] = TypedRedisValue.BulkString("pong");
span[1] = TypedRedisValue.BulkString(request.Count == 1 ? RedisValue.EmptyString : request.GetValue(1));
return pong;
}
}
[RedisCommand(1, LockFree = true)] [RedisCommand(1, LockFree = true)]
protected virtual TypedRedisValue Quit(RedisClient client, RedisRequest request) protected virtual TypedRedisValue Quit(RedisClient client, RedisRequest request)
...@@ -465,10 +480,11 @@ protected virtual TypedRedisValue Select(RedisClient client, RedisRequest reques ...@@ -465,10 +480,11 @@ protected virtual TypedRedisValue Select(RedisClient client, RedisRequest reques
return TypedRedisValue.OK; return TypedRedisValue.OK;
} }
[RedisCommand(-2)] [RedisCommand(-2, LockFree = true)]
protected virtual TypedRedisValue Subscribe(RedisClient client, RedisRequest request) protected virtual TypedRedisValue Subscribe(RedisClient client, RedisRequest request)
=> SubscribeImpl(client, request); => SubscribeImpl(client, request);
[RedisCommand(-2)]
[RedisCommand(-2, LockFree = true)]
protected virtual TypedRedisValue Unsubscribe(RedisClient client, RedisRequest request) protected virtual TypedRedisValue Unsubscribe(RedisClient client, RedisRequest request)
=> SubscribeImpl(client, request); => SubscribeImpl(client, request);
...@@ -507,6 +523,66 @@ private static readonly CommandBytes ...@@ -507,6 +523,66 @@ private static readonly CommandBytes
s_Unsubscribe = new CommandBytes("unsubscribe"); s_Unsubscribe = new CommandBytes("unsubscribe");
private static readonly DateTime UnixEpoch = new DateTime(1970, 1, 1, 0, 0, 0, DateTimeKind.Utc); private static readonly DateTime UnixEpoch = new DateTime(1970, 1, 1, 0, 0, 0, DateTimeKind.Utc);
static readonly RedisValue s_MESSAGE = Encoding.ASCII.GetBytes("message");
static TypedRedisValue CreateBroadcastMessage(RedisChannel channel, RedisValue payload)
{
var msg = TypedRedisValue.Rent(3, out var span);
span[0] = TypedRedisValue.BulkString(s_MESSAGE);
span[1] = TypedRedisValue.BulkString(channel.Value);
span[2] = TypedRedisValue.BulkString(payload);
return msg;
}
private async ValueTask<bool> TrySendOutOfBandAsync(RedisClient client, TypedRedisValue value)
{
try
{
var output = client?.LinkedPipe?.Output;
if (output == null)
{
Console.WriteLine("No pipe");
return false;
}
await WriteResponseAsync(client, output, value);
return true;
}
catch (Exception ex)
{
Console.WriteLine(ex.Message);
return false;
}
}
private async Task BackgroundPublish(ArraySegment<RedisClient> clients, RedisChannel channel, RedisValue payload)
{
var msg = CreateBroadcastMessage(channel, payload);
foreach (var sub in clients)
{
await TrySendOutOfBandAsync(sub, msg);
}
// only recycle on success, to avoid issues
msg.Recycle();
ArrayPool<RedisClient>.Shared.Return(clients.Array);
}
[RedisCommand(3, LockFree = true)]
protected virtual TypedRedisValue Publish(RedisClient client, RedisRequest request)
{
var channel = request.GetChannel(1, RedisChannel.PatternMode.Literal);
var subscribers = FilterSubscribers(channel);
int count = subscribers.Count;
if (count != 0)
{
var payload = request.GetValue(2);
Task.Run(() => BackgroundPublish(subscribers, channel, payload));
}
return TypedRedisValue.Integer(count);
}
[RedisCommand(1, LockFree = true)] [RedisCommand(1, LockFree = true)]
protected virtual TypedRedisValue Time(RedisClient client, RedisRequest request) protected virtual TypedRedisValue Time(RedisClient client, RedisRequest request)
{ {
......
...@@ -189,17 +189,41 @@ internal int NetArity() ...@@ -189,17 +189,41 @@ internal int NetArity()
// for extensibility, so that a subclass can get their own client type // for extensibility, so that a subclass can get their own client type
// to be used via ListenForConnections // to be used via ListenForConnections
public virtual RedisClient CreateClient() => new RedisClient(); public virtual RedisClient CreateClient(IDuplexPipe pipe) => new RedisClient { LinkedPipe = pipe };
public int ClientCount public int ClientCount
{ {
get { lock (_clients) { return _clients.Count; } } get { lock (_clients) { return _clients.Count; } }
} }
protected ArraySegment<RedisClient> FilterSubscribers(RedisChannel channel)
{
lock (_clients)
{
var arr = ArrayPool<RedisClient>.Shared.Rent(_clients.Count);
int count = 0;
foreach(var client in _clients)
{
if (client.IsSubscribed(channel))
arr[count++] = client;
}
if (count == 0)
{
ArrayPool<RedisClient>.Shared.Return(arr);
return default; // Count=0, importantly
}
else
{
return new ArraySegment<RedisClient>(arr, 0, count);
}
}
}
public int TotalClientCount { get; private set; } public int TotalClientCount { get; private set; }
private int _nextId; private int _nextId;
public RedisClient AddClient() public RedisClient AddClient(IDuplexPipe pipe)
{ {
var client = CreateClient(); var client = CreateClient(pipe);
lock (_clients) lock (_clients)
{ {
ThrowIfShutdown(); ThrowIfShutdown();
...@@ -250,7 +274,7 @@ public async Task RunClientAsync(IDuplexPipe pipe) ...@@ -250,7 +274,7 @@ public async Task RunClientAsync(IDuplexPipe pipe)
RedisClient client = null; RedisClient client = null;
try try
{ {
client = AddClient(); client = AddClient(pipe);
while (!client.Closed) while (!client.Closed)
{ {
var readResult = await pipe.Input.ReadAsync().ConfigureAwait(false); var readResult = await pipe.Input.ReadAsync().ConfigureAwait(false);
...@@ -314,6 +338,13 @@ void WritePrefix(PipeWriter ooutput, char pprefix) ...@@ -314,6 +338,13 @@ void WritePrefix(PipeWriter ooutput, char pprefix)
if (value.IsNil) return; // not actually a request (i.e. empty/whitespace request) if (value.IsNil) return; // not actually a request (i.e. empty/whitespace request)
if (client != null && client.ShouldSkipResponse()) return; // intentionally skipping the result if (client != null && client.ShouldSkipResponse()) return; // intentionally skipping the result
if (client != null)
{
await client.TakeWriteLockAsync();
}
try
{
char prefix; char prefix;
switch (value.Type) switch (value.Type)
{ {
...@@ -363,6 +394,14 @@ void WritePrefix(PipeWriter ooutput, char pprefix) ...@@ -363,6 +394,14 @@ void WritePrefix(PipeWriter ooutput, char pprefix)
} }
await output.FlushAsync().ConfigureAwait(false); await output.FlushAsync().ConfigureAwait(false);
} }
finally
{
if (client != null)
{
client.ReleasseWriteLock();
}
}
}
public static bool TryParseRequest(ref ReadOnlySequence<byte> buffer, out RedisRequest request) public static bool TryParseRequest(ref ReadOnlySequence<byte> buffer, out RedisRequest request)
{ {
var reader = new BufferReader(buffer); var reader = new BufferReader(buffer);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment