Commit 40c5d21d authored by Marc Gravell's avatar Marc Gravell

tidy up kestrel code; flush more eagerly when writing

parent 7906ad3b
using System.Threading.Tasks; using System.Threading.Tasks;
using Microsoft.AspNetCore.Connections; using Microsoft.AspNetCore.Connections;
using Microsoft.Extensions.Logging;
using StackExchange.Redis.Server; using StackExchange.Redis.Server;
namespace KestrelRedisServer namespace KestrelRedisServer
{ {
public class RedisConnectionHandler : ConnectionHandler public class RedisConnectionHandler : ConnectionHandler
{ {
private readonly MemoryCacheRedisServer _server; private readonly RespServer _server;
public RedisConnectionHandler(ILogger<RedisConnectionHandler> logger) public RedisConnectionHandler(RespServer server) => _server = server;
{ public override Task OnConnectedAsync(ConnectionContext connection)
_server = new MemoryCacheRedisServer(); => _server.RunClient(connection.Transport);
}
public override async Task OnConnectedAsync(ConnectionContext connection)
{
var client = _server.AddClient();
try
{
while (true)
{
var read = await connection.Transport.Input.ReadAsync();
var buffer = read.Buffer;
bool makingProgress = false;
while (_server.TryProcessRequest(ref buffer, client, connection.Transport.Output))
{
makingProgress = true;
await connection.Transport.Output.FlushAsync();
}
connection.Transport.Input.AdvanceTo(buffer.Start, buffer.End);
if (!makingProgress && read.IsCompleted) break;
}
}
catch (ConnectionResetException) { } // swallow
finally
{
_server.RemoveClient(client);
connection.Transport.Input.Complete();
connection.Transport.Output.Complete();
}
}
} }
} }
using Microsoft.AspNetCore.Builder; using System;
using Microsoft.AspNetCore.Builder;
using Microsoft.AspNetCore.Hosting; using Microsoft.AspNetCore.Hosting;
using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Http;
using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.DependencyInjection;
using StackExchange.Redis.Server;
namespace KestrelRedisServer namespace KestrelRedisServer
{ {
public class Startup public class Startup : IDisposable
{ {
RespServer _server = new MemoryCacheRedisServer();
// This method gets called by the runtime. Use this method to add services to the container. // This method gets called by the runtime. Use this method to add services to the container.
// For more information on how to configure your application, visit https://go.microsoft.com/fwlink/?LinkID=398940 // For more information on how to configure your application, visit https://go.microsoft.com/fwlink/?LinkID=398940
public void ConfigureServices(IServiceCollection services) public void ConfigureServices(IServiceCollection services)
{ => services.Add(new ServiceDescriptor(typeof(RespServer), _server));
}
public void Dispose() => _server.Dispose();
// This method gets called by the runtime. Use this method to configure the HTTP request pipeline. // This method gets called by the runtime. Use this method to configure the HTTP request pipeline.
public void Configure(IApplicationBuilder app, IHostingEnvironment env) public void Configure(IApplicationBuilder app, IHostingEnvironment env)
{ {
if (env.IsDevelopment()) if (env.IsDevelopment()) app.UseDeveloperExceptionPage();
{ app.Run(context => context.Response.WriteAsync(_server.GetStats()));
app.UseDeveloperExceptionPage();
}
app.Run(async (context) =>
{
await context.Response.WriteAsync("Redis-ish server should be running");
});
} }
} }
} }
...@@ -16,12 +16,29 @@ public static bool IsMatch(string pattern, string key) ...@@ -16,12 +16,29 @@ public static bool IsMatch(string pattern, string key)
protected RedisServer(int databases = 16, TextWriter output = null) : base(output) protected RedisServer(int databases = 16, TextWriter output = null) : base(output)
{ {
if (databases < 1) throw new ArgumentOutOfRangeException(nameof(databases)); if (databases < 1) throw new ArgumentOutOfRangeException(nameof(databases));
Databases = databases;
var config = ServerConfiguration; var config = ServerConfiguration;
config["timeout"] = "0"; config["timeout"] = "0";
config["slave-read-only"] = "yes"; config["slave-read-only"] = "yes";
config["databases"] = databases.ToString(); config["databases"] = databases.ToString();
config["slaveof"] = ""; config["slaveof"] = "";
} }
protected override void AppendStats(StringBuilder sb)
{
base.AppendStats(sb);
sb.Append("Databases: ").Append(Databases).AppendLine();
lock (ServerSyncLock)
{
for (int i = 0; i < Databases; i++)
{
try
{
sb.Append("Database ").Append(i).Append(": ").Append(Dbsize(i)).AppendLine(" keys");
}
catch { }
}
}
}
public int Databases { get; } public int Databases { get; }
[RedisCommand(-3)] [RedisCommand(-3)]
...@@ -374,7 +391,7 @@ StringBuilder AddHeader() ...@@ -374,7 +391,7 @@ StringBuilder AddHeader()
break; break;
case "Stats": case "Stats":
AddHeader().Append("total_connections_received:").Append(TotalClientCount).AppendLine() AddHeader().Append("total_connections_received:").Append(TotalClientCount).AppendLine()
.Append("total_commands_processed:").Append(CommandsProcesed).AppendLine(); .Append("total_commands_processed:").Append(TotalCommandsProcesed).AppendLine();
break; break;
case "Replication": case "Replication":
AddHeader().AppendLine("role:master"); AddHeader().AppendLine("role:master");
......
...@@ -59,6 +59,21 @@ RedisCommandAttribute CheckSignatureAndGetAttribute(MethodInfo method) ...@@ -59,6 +59,21 @@ RedisCommandAttribute CheckSignatureAndGetAttribute(MethodInfo method)
} }
return result; return result;
} }
public string GetStats()
{
var sb = new StringBuilder();
AppendStats(sb);
return sb.ToString();
}
protected virtual void AppendStats(StringBuilder sb)
{
sb.Append("Current clients:\t").Append(ClientCount).AppendLine()
.Append("Total clients:\t").Append(TotalClientCount).AppendLine()
.Append("Total operations:\t").Append(TotalCommandsProcesed).AppendLine()
.Append("Error replies:\t").Append(TotalErrorCount).AppendLine();
}
[AttributeUsage(AttributeTargets.Method, AllowMultiple = false, Inherited = true)] [AttributeUsage(AttributeTargets.Method, AllowMultiple = false, Inherited = true)]
protected sealed class RedisCommandAttribute : Attribute protected sealed class RedisCommandAttribute : Attribute
{ {
...@@ -171,8 +186,9 @@ protected int TcpPort() ...@@ -171,8 +186,9 @@ protected int TcpPort()
} }
private Action<object> _runClientCallback; private Action<object> _runClientCallback;
// KeepAlive here just to make the compiler happy that we've done *something* with the task
private Action<object> RunClientCallback => _runClientCallback ?? private Action<object> RunClientCallback => _runClientCallback ??
(_runClientCallback = state => RunClient((RedisClient)state)); (_runClientCallback = state => GC.KeepAlive(RunClient((IDuplexPipe)state)));
public void Listen( public void Listen(
EndPoint endpoint, EndPoint endpoint,
...@@ -212,8 +228,8 @@ public RedisClient AddClient() ...@@ -212,8 +228,8 @@ public RedisClient AddClient()
var client = CreateClient(); var client = CreateClient();
lock (_clients) lock (_clients)
{ {
client.Id = ++_nextId;
ThrowIfShutdown(); ThrowIfShutdown();
client.Id = ++_nextId;
_clients.Add(client); _clients.Add(client);
TotalClientCount++; TotalClientCount++;
} }
...@@ -221,7 +237,7 @@ public RedisClient AddClient() ...@@ -221,7 +237,7 @@ public RedisClient AddClient()
} }
public bool RemoveClient(RedisClient client) public bool RemoveClient(RedisClient client)
{ {
if (client == null) throw new ArgumentNullException(nameof(client)); if (client == null) return false;
lock (_clients) lock (_clients)
{ {
client.Closed = true; client.Closed = true;
...@@ -237,16 +253,14 @@ private async void ListenForConnections(PipeOptions sendOptions, PipeOptions rec ...@@ -237,16 +253,14 @@ private async void ListenForConnections(PipeOptions sendOptions, PipeOptions rec
var client = await _listener.AcceptAsync(); var client = await _listener.AcceptAsync();
SocketConnection.SetRecommendedServerOptions(client); SocketConnection.SetRecommendedServerOptions(client);
var pipe = SocketConnection.Create(client, sendOptions, receiveOptions); var pipe = SocketConnection.Create(client, sendOptions, receiveOptions);
var c = AddClient(); StartOnScheduler(receiveOptions.ReaderScheduler, RunClientCallback, pipe);
c.LinkedPipe = pipe;
StartOnScheduler(receiveOptions.ReaderScheduler, RunClientCallback, c);
} }
} }
catch (NullReferenceException) { } catch (NullReferenceException) { }
catch (ObjectDisposedException) { } catch (ObjectDisposedException) { }
catch (Exception ex) catch (Exception ex)
{ {
if(!_isShutdown) Log("Listener faulted: " + ex.Message); if (!_isShutdown) Log("Listener faulted: " + ex.Message);
} }
} }
...@@ -281,33 +295,27 @@ protected virtual void Dispose(bool disposing) ...@@ -281,33 +295,27 @@ protected virtual void Dispose(bool disposing)
} }
} }
async void RunClient(RedisClient client) public async Task RunClient(IDuplexPipe pipe)
{ {
ThrowIfShutdown();
var input = client?.LinkedPipe?.Input;
var output = client?.LinkedPipe?.Output;
if (input == null || output == null) return; // nope
Exception fault = null; Exception fault = null;
RedisClient client = null;
try try
{ {
client = AddClient();
while (!client.Closed) while (!client.Closed)
{ {
var readResult = await input.ReadAsync(); var readResult = await pipe.Input.ReadAsync();
var buffer = readResult.Buffer; var buffer = readResult.Buffer;
bool makingProgress = false; bool makingProgress = false;
while (!client.Closed && TryProcessRequest(ref buffer, client, output)) while (!client.Closed && await TryProcessRequestAsync(ref buffer, client, pipe.Output))
{ {
makingProgress = true; makingProgress = true;
await output.FlushAsync();
} }
input.AdvanceTo(buffer.Start, buffer.End); pipe.Input.AdvanceTo(buffer.Start, buffer.End);
if (!makingProgress && readResult.IsCompleted) if (!makingProgress && readResult.IsCompleted)
{ { // nothing to do, and nothing more will be arriving
break; break;
} }
} }
...@@ -317,8 +325,9 @@ async void RunClient(RedisClient client) ...@@ -317,8 +325,9 @@ async void RunClient(RedisClient client)
catch (Exception ex) { fault = ex; } catch (Exception ex) { fault = ex; }
finally finally
{ {
try { input.Complete(fault); } catch { } RemoveClient(client);
try { output.Complete(fault); } catch { } try { pipe.Input.Complete(fault); } catch { }
try { pipe.Output.Complete(fault); } catch { }
if (fault != null && !_isShutdown) if (fault != null && !_isShutdown)
{ {
...@@ -339,14 +348,29 @@ private void Log(string message) ...@@ -339,14 +348,29 @@ private void Log(string message)
} }
static Encoder s_sharedEncoder; // swapped in/out to avoid alloc on the public WriteResponse API static Encoder s_sharedEncoder; // swapped in/out to avoid alloc on the public WriteResponse API
public static void WriteResponse(RedisClient client, PipeWriter output, RedisResult response) public static ValueTask WriteResponseAsync(RedisClient client, PipeWriter output, RedisResult response)
{ {
async ValueTask Awaited(ValueTask wwrite, Encoder eenc)
{
await wwrite;
Interlocked.Exchange(ref s_sharedEncoder, eenc);
}
var enc = Interlocked.Exchange(ref s_sharedEncoder, null) ?? Encoding.UTF8.GetEncoder(); var enc = Interlocked.Exchange(ref s_sharedEncoder, null) ?? Encoding.UTF8.GetEncoder();
WriteResponse(client, output, response, enc); var write = WriteResponseAsync(client, output, response, enc);
if (!write.IsCompletedSuccessfully) return Awaited(write, enc);
Interlocked.Exchange(ref s_sharedEncoder, enc); Interlocked.Exchange(ref s_sharedEncoder, enc);
return default;
} }
internal static void WriteResponse(RedisClient client, PipeWriter output, RedisResult response, Encoder encoder)
internal static async ValueTask WriteResponseAsync(RedisClient client, PipeWriter output, RedisResult response, Encoder encoder)
{
void WritePrefix(PipeWriter ooutput, char pprefix)
{ {
var span = ooutput.GetSpan(1);
span[0] = (byte)pprefix;
ooutput.Advance(1);
}
if (response == null) return; // not actually a request (i.e. empty/whitespace request) if (response == null) return; // not actually a request (i.e. empty/whitespace request)
if (client != null && client.ShouldSkipResponse()) return; // intentionally skipping the result if (client != null && client.ShouldSkipResponse()) return; // intentionally skipping the result
char prefix; char prefix;
...@@ -361,9 +385,7 @@ internal static void WriteResponse(RedisClient client, PipeWriter output, RedisR ...@@ -361,9 +385,7 @@ internal static void WriteResponse(RedisClient client, PipeWriter output, RedisR
case ResultType.SimpleString: case ResultType.SimpleString:
prefix = '+'; prefix = '+';
BasicMessage: BasicMessage:
var span = output.GetSpan(1); WritePrefix(output, prefix);
span[0] = (byte)prefix;
output.Advance(1);
var val = response.AsString(); var val = response.AsString();
...@@ -388,7 +410,9 @@ internal static void WriteResponse(RedisClient client, PipeWriter output, RedisR ...@@ -388,7 +410,9 @@ internal static void WriteResponse(RedisClient client, PipeWriter output, RedisR
var item = arr[i]; var item = arr[i];
if (item == null) if (item == null)
throw new InvalidOperationException("Array element cannot be null, index " + i); throw new InvalidOperationException("Array element cannot be null, index " + i);
WriteResponse(null, output, item, encoder); // note: don't pass client down; this would impact SkipReplies
// note: don't pass client down; this would impact SkipReplies
await WriteResponseAsync(null, output, item, encoder);
} }
} }
break; break;
...@@ -396,6 +420,7 @@ internal static void WriteResponse(RedisClient client, PipeWriter output, RedisR ...@@ -396,6 +420,7 @@ internal static void WriteResponse(RedisClient client, PipeWriter output, RedisR
throw new InvalidOperationException( throw new InvalidOperationException(
"Unexpected result type: " + response.Type); "Unexpected result type: " + response.Type);
} }
await output.FlushAsync();
} }
public static bool TryParseRequest(ref ReadOnlySequence<byte> buffer, out RedisRequest request) public static bool TryParseRequest(ref ReadOnlySequence<byte> buffer, out RedisRequest request)
{ {
...@@ -411,32 +436,39 @@ public static bool TryParseRequest(ref ReadOnlySequence<byte> buffer, out RedisR ...@@ -411,32 +436,39 @@ public static bool TryParseRequest(ref ReadOnlySequence<byte> buffer, out RedisR
return false; return false;
} }
public bool TryProcessRequest(ref ReadOnlySequence<byte> buffer, RedisClient client, PipeWriter output) public ValueTask<bool> TryProcessRequestAsync(ref ReadOnlySequence<byte> buffer, RedisClient client, PipeWriter output)
{
async ValueTask<bool> Awaited(ValueTask wwrite)
{ {
await wwrite;
return true;
}
if (!buffer.IsEmpty && TryParseRequest(ref buffer, out var request)) if (!buffer.IsEmpty && TryParseRequest(ref buffer, out var request))
{ {
RedisResult response; RedisResult response;
try { response = Execute(client, request); } try { response = Execute(client, request); }
finally { request.Recycle(); } finally { request.Recycle(); }
WriteResponse(client, output, response); var write = WriteResponseAsync(client, output, response);
return true; if (!write.IsCompletedSuccessfully) return Awaited(write);
return new ValueTask<bool>(true);
} }
return false; return new ValueTask<bool>(false);
} }
private object ServerSyncLock => this; protected object ServerSyncLock => this;
private long _commandsProcesed; private long _totalCommandsProcesed, _totalErrorCount;
public long CommandsProcesed => _commandsProcesed; public long TotalCommandsProcesed => _totalCommandsProcesed;
public long TotalErrorCount => _totalErrorCount;
public RedisResult Execute(RedisClient client, RedisRequest request) public RedisResult Execute(RedisClient client, RedisRequest request)
{ {
if (string.IsNullOrWhiteSpace(request.Command)) return null; // not a request if (string.IsNullOrWhiteSpace(request.Command)) return null; // not a request
Interlocked.Increment(ref _commandsProcesed); Interlocked.Increment(ref _totalCommandsProcesed);
try try
{ {
RedisResult result; RedisResult result;
if(_commands.TryGetValue(request.Command, out var cmd)) if (_commands.TryGetValue(request.Command, out var cmd))
{ {
request = request.AsCommand(cmd.Command); // fixup casing request = request.AsCommand(cmd.Command); // fixup casing
if (cmd.HasSubCommands) if (cmd.HasSubCommands)
...@@ -444,13 +476,13 @@ public RedisResult Execute(RedisClient client, RedisRequest request) ...@@ -444,13 +476,13 @@ public RedisResult Execute(RedisClient client, RedisRequest request)
cmd = cmd.Resolve(request); cmd = cmd.Resolve(request);
if (cmd.IsUnknown) return request.UnknownSubcommandOrArgumentCount(); if (cmd.IsUnknown) return request.UnknownSubcommandOrArgumentCount();
} }
if(cmd.LockFree) if (cmd.LockFree)
{ {
result = cmd.Execute(client, request); result = cmd.Execute(client, request);
} }
else else
{ {
lock(ServerSyncLock) lock (ServerSyncLock)
{ {
result = cmd.Execute(client, request); result = cmd.Execute(client, request);
} }
...@@ -462,6 +494,7 @@ public RedisResult Execute(RedisClient client, RedisRequest request) ...@@ -462,6 +494,7 @@ public RedisResult Execute(RedisClient client, RedisRequest request)
} }
if (result == null) Log($"missing command: '{request.Command}'"); if (result == null) Log($"missing command: '{request.Command}'");
else if (result.Type == ResultType.Error) Interlocked.Increment(ref _totalErrorCount);
return result ?? CommandNotFound(request.Command); return result ?? CommandNotFound(request.Command);
} }
catch (NotSupportedException) catch (NotSupportedException)
...@@ -480,7 +513,7 @@ public RedisResult Execute(RedisClient client, RedisRequest request) ...@@ -480,7 +513,7 @@ public RedisResult Execute(RedisClient client, RedisRequest request)
} }
catch (Exception ex) catch (Exception ex)
{ {
if(!_isShutdown) Log(ex.Message); if (!_isShutdown) Log(ex.Message);
return RedisResult.Create("ERR " + ex.Message, ResultType.Error); return RedisResult.Create("ERR " + ex.Message, ResultType.Error);
} }
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment