Commit 63421551 authored by Marc Gravell's avatar Marc Gravell

heap-allocated string-like type for representing commands; avoids the `string` alloc

parent 129ec274
using System;
using System.Text;
namespace StackExchange.Redis
{
public unsafe struct CommandBytes : IEquatable<CommandBytes>
{
public override int GetHashCode() => _hashcode;
public override string ToString()
{
fixed (byte* ptr = _bytes)
{
return Encoding.UTF8.GetString(ptr, Length);
}
}
public byte this[int index]
{
get
{
if (index < 0 || index >= Length) throw new IndexOutOfRangeException();
fixed (byte* ptr = _bytes)
{
return ptr[index];
}
}
}
public const int MaxLength = 32; // mut be multiple of 8
public int Length { get; }
readonly int _hashcode;
fixed byte _bytes[MaxLength];
public CommandBytes(string value)
{
value = value.ToLowerInvariant();
Length = Encoding.UTF8.GetByteCount(value);
if (Length > MaxLength) throw new ArgumentOutOfRangeException("Maximum command length exceeed");
fixed (byte* bPtr = _bytes)
{
Clear((long*)bPtr);
fixed (char* cPtr = value)
{
Encoding.UTF8.GetBytes(cPtr, value.Length, bPtr, Length);
}
_hashcode = GetHashCode(bPtr, Length);
}
}
public override bool Equals(object obj) => obj is CommandBytes cb && Equals(cb);
public bool Equals(CommandBytes value)
{
if (_hashcode != value._hashcode || Length != value.Length)
return false;
fixed (byte* thisB = _bytes)
{
var thisL = (long*)thisB;
var otherL = (long*)value._bytes;
int chunks = (Length + 7) >> 3;
for (int i = 0; i < chunks; i++)
{
if (thisL[i] != otherL[i]) return false;
}
}
return true;
}
private static void Clear(long* ptr)
{
for (int i = 0; i < (MaxLength >> 3) ; i++)
{
ptr[i] = 0;
}
}
public CommandBytes(ReadOnlySpan<byte> value)
{
Length = value.Length;
if (Length > MaxLength) throw new ArgumentOutOfRangeException("Maximum command length exceeed");
fixed (byte* bPtr = _bytes)
{
Clear((long*)bPtr);
for (int i = 0; i < value.Length; i++)
{
bPtr[i] = ToLowerInvariant(value[i]);
}
_hashcode = GetHashCode(bPtr, Length);
}
}
static int GetHashCode(byte* ptr, int count)
{
var hc = count;
for (int i = 0; i < count; i++)
{
hc = (hc * -13547) + ptr[i];
}
return hc;
}
static byte ToLowerInvariant(byte b) => b >= 'A' && b <= 'Z' ? (byte)(b | 32) : b;
internal byte[] ToArray()
{
fixed (byte* ptr = _bytes)
{
return new Span<byte>(ptr, Length).ToArray();
}
}
}
}
...@@ -10,33 +10,31 @@ namespace StackExchange.Redis.Server ...@@ -10,33 +10,31 @@ namespace StackExchange.Redis.Server
private readonly RawResult _inner; private readonly RawResult _inner;
public int Count { get; } public int Count { get; }
public string Command { get; }
public override string ToString() => Command; public override string ToString() => Count == 0 ? "(n/a)" : GetString(0).ToString();
public override bool Equals(object obj) => throw new NotSupportedException(); public override bool Equals(object obj) => throw new NotSupportedException();
public TypedRedisValue WrongArgCount() => TypedRedisValue.Error($"ERR wrong number of arguments for '{Command}' command"); public TypedRedisValue WrongArgCount() => TypedRedisValue.Error($"ERR wrong number of arguments for '{ToString()}' command");
public TypedRedisValue CommandNotFound()
=> TypedRedisValue.Error($"ERR unknown command '{ToString()}'");
public TypedRedisValue UnknownSubcommandOrArgumentCount() => TypedRedisValue.Error($"ERR Unknown subcommand or wrong number of arguments for '{Command}'."); public TypedRedisValue UnknownSubcommandOrArgumentCount() => TypedRedisValue.Error($"ERR Unknown subcommand or wrong number of arguments for '{ToString()}'.");
public string GetString(int index) public string GetString(int index)
=> _inner[index].GetString(); => _inner[index].GetString();
public bool IsString(int index, string value) // TODO: optimize public bool IsString(int index, string value) // TODO: optimize
=> string.Equals(value, _inner[index].GetString(), StringComparison.OrdinalIgnoreCase); => string.Equals(value, _inner[index].GetString(), StringComparison.OrdinalIgnoreCase);
public override int GetHashCode() => throw new NotSupportedException(); public override int GetHashCode() => throw new NotSupportedException();
internal RedisRequest(RawResult result) internal RedisRequest(RawResult result)
: this(result, result.ItemsCount, result[0].GetString()) { }
private RedisRequest(RawResult inner, int count, string command)
{ {
_inner = inner; _inner = result;
Count = count; Count = result.ItemsCount;
Command = command;
} }
internal RedisRequest AsCommand(string command)
=> new RedisRequest(_inner, Count, command);
internal void Recycle() => _inner.Recycle(); internal void Recycle() => _inner.Recycle();
public RedisValue GetValue(int index) public RedisValue GetValue(int index)
...@@ -46,10 +44,41 @@ public int GetInt32(int index) ...@@ -46,10 +44,41 @@ public int GetInt32(int index)
=> (int)_inner[index].AsRedisValue(); => (int)_inner[index].AsRedisValue();
public long GetInt64(int index) => (long)_inner[index].AsRedisValue(); public long GetInt64(int index) => (long)_inner[index].AsRedisValue();
public RedisKey GetKey(int index) => _inner[index].AsRedisKey(); public RedisKey GetKey(int index) => _inner[index].AsRedisKey();
public RedisChannel GetChannel(int index, RedisChannel.PatternMode mode) public RedisChannel GetChannel(int index, RedisChannel.PatternMode mode)
=> _inner[index].AsRedisChannel(null, mode); => _inner[index].AsRedisChannel(null, mode);
internal bool TryGetCommandBytes(int i, out CommandBytes command)
{
var payload = _inner[i].DirecyPayload;
if (payload.Length > CommandBytes.MaxLength)
{
command = default;
return false;
}
if (payload.Length == 0)
{
command = default;
}
else if (payload.IsSingleSegment)
{
command = new CommandBytes(payload.First.Span);
}
else
{
Span<byte> span = stackalloc byte[CommandBytes.MaxLength];
var sliced = span;
foreach (var segment in payload)
{
segment.Span.CopyTo(sliced);
sliced = sliced.Slice(segment.Length);
}
command = new CommandBytes(span.Slice(0, (int)payload.Length));
}
return true;
}
} }
} }
...@@ -110,7 +110,7 @@ protected virtual TypedRedisValue ClientReply(RedisClient client, RedisRequest r ...@@ -110,7 +110,7 @@ protected virtual TypedRedisValue ClientReply(RedisClient client, RedisRequest r
[RedisCommand(-1)] [RedisCommand(-1)]
protected virtual TypedRedisValue Cluster(RedisClient client, RedisRequest request) protected virtual TypedRedisValue Cluster(RedisClient client, RedisRequest request)
=> CommandNotFound(request.Command); => request.CommandNotFound();
[RedisCommand(-3)] [RedisCommand(-3)]
protected virtual TypedRedisValue Lpush(RedisClient client, RedisRequest request) protected virtual TypedRedisValue Lpush(RedisClient client, RedisRequest request)
...@@ -476,25 +476,35 @@ private TypedRedisValue SubscribeImpl(RedisClient client, RedisRequest request) ...@@ -476,25 +476,35 @@ private TypedRedisValue SubscribeImpl(RedisClient client, RedisRequest request)
{ {
var reply = TypedRedisValue.Rent(3 * (request.Count - 1), out var span); var reply = TypedRedisValue.Rent(3 * (request.Count - 1), out var span);
int index = 0; int index = 0;
var mode = request.Command[0] == 'p' ? RedisChannel.PatternMode.Pattern : RedisChannel.PatternMode.Literal; request.TryGetCommandBytes(0, out var cmd);
var cmdString = TypedRedisValue.BulkString(cmd.ToArray());
var mode = cmd[0] == (byte)'p' ? RedisChannel.PatternMode.Pattern : RedisChannel.PatternMode.Literal;
for (int i = 1; i < request.Count; i++) for (int i = 1; i < request.Count; i++)
{ {
var channel = request.GetChannel(i, mode); var channel = request.GetChannel(i, mode);
int count; int count;
switch (request.Command) if (s_Subscribe.Equals(cmd))
{ {
case "subscribe": count = client.Subscribe(channel); break; count = client.Subscribe(channel);
case "unsubscribe": count = client.Unsubscribe(channel); break;
default:
reply.Recycle(index);
return TypedRedisValue.Nil;
} }
span[index++] = TypedRedisValue.BulkString(request.Command); else if (s_Unsubscribe.Equals(cmd))
{
count = client.Unsubscribe(channel);
}
else
{
reply.Recycle(index);
return TypedRedisValue.Nil;
}
span[index++] = cmdString;
span[index++] = TypedRedisValue.BulkString((byte[])channel); span[index++] = TypedRedisValue.BulkString((byte[])channel);
span[index++] = TypedRedisValue.Integer(count); span[index++] = TypedRedisValue.Integer(count);
} }
return reply; return reply;
} }
static readonly CommandBytes
s_Subscribe = new CommandBytes("subscribe"),
s_Unsubscribe = new CommandBytes("unsubscribe");
static readonly DateTime UnixEpoch = new DateTime(1970, 1, 1, 0, 0, 0, DateTimeKind.Utc); static readonly DateTime UnixEpoch = new DateTime(1970, 1, 1, 0, 0, 0, DateTimeKind.Utc);
[RedisCommand(1, LockFree = true)] [RedisCommand(1, LockFree = true)]
......
...@@ -22,14 +22,14 @@ public enum ShutdownReason ...@@ -22,14 +22,14 @@ public enum ShutdownReason
} }
private readonly List<RedisClient> _clients = new List<RedisClient>(); private readonly List<RedisClient> _clients = new List<RedisClient>();
private readonly TextWriter _output; private readonly TextWriter _output;
public RespServer(TextWriter output = null) public RespServer(TextWriter output = null)
{ {
_output = output; _output = output;
_commands = BuildCommands(this); _commands = BuildCommands(this);
} }
private static Dictionary<string, RespCommand> BuildCommands(RespServer server) private static Dictionary<CommandBytes, RespCommand> BuildCommands(RespServer server)
{ {
RedisCommandAttribute CheckSignatureAndGetAttribute(MethodInfo method) RedisCommandAttribute CheckSignatureAndGetAttribute(MethodInfo method)
{ {
...@@ -45,7 +45,7 @@ RedisCommandAttribute CheckSignatureAndGetAttribute(MethodInfo method) ...@@ -45,7 +45,7 @@ RedisCommandAttribute CheckSignatureAndGetAttribute(MethodInfo method)
select new RespCommand(attrib, method, server) into cmd select new RespCommand(attrib, method, server) into cmd
group cmd by cmd.Command; group cmd by cmd.Command;
var result = new Dictionary<string, RespCommand>(StringComparer.OrdinalIgnoreCase); var result = new Dictionary<CommandBytes, RespCommand>();
foreach (var grp in grouped) foreach (var grp in grouped)
{ {
RespCommand parent; RespCommand parent;
...@@ -58,7 +58,7 @@ RedisCommandAttribute CheckSignatureAndGetAttribute(MethodInfo method) ...@@ -58,7 +58,7 @@ RedisCommandAttribute CheckSignatureAndGetAttribute(MethodInfo method)
{ {
parent = grp.Single(); parent = grp.Single();
} }
result.Add(grp.Key, parent); result.Add(new CommandBytes(grp.Key), parent);
} }
return result; return result;
} }
...@@ -94,7 +94,7 @@ protected sealed class RedisCommandAttribute : Attribute ...@@ -94,7 +94,7 @@ protected sealed class RedisCommandAttribute : Attribute
public int Arity { get; } public int Arity { get; }
public bool LockFree { get; set; } public bool LockFree { get; set; }
} }
private readonly Dictionary<string, RespCommand> _commands; private readonly Dictionary<CommandBytes, RespCommand> _commands;
readonly struct RespCommand readonly struct RespCommand
{ {
...@@ -102,12 +102,14 @@ public RespCommand(RedisCommandAttribute attrib, MethodInfo method, RespServer s ...@@ -102,12 +102,14 @@ public RespCommand(RedisCommandAttribute attrib, MethodInfo method, RespServer s
{ {
_operation = (RespOperation)Delegate.CreateDelegate(typeof(RespOperation), server, method); _operation = (RespOperation)Delegate.CreateDelegate(typeof(RespOperation), server, method);
Command = (string.IsNullOrWhiteSpace(attrib.Command) ? method.Name : attrib.Command).Trim().ToLowerInvariant(); Command = (string.IsNullOrWhiteSpace(attrib.Command) ? method.Name : attrib.Command).Trim().ToLowerInvariant();
CommandBytes = new CommandBytes(Command);
SubCommand = attrib.SubCommand?.Trim()?.ToLowerInvariant(); SubCommand = attrib.SubCommand?.Trim()?.ToLowerInvariant();
Arity = attrib.Arity; Arity = attrib.Arity;
MaxArgs = attrib.MaxArgs; MaxArgs = attrib.MaxArgs;
LockFree = attrib.LockFree; LockFree = attrib.LockFree;
_subcommands = null; _subcommands = null;
} }
CommandBytes CommandBytes { get; }
public string Command { get; } public string Command { get; }
public string SubCommand { get; } public string SubCommand { get; }
public bool IsSubCommand => !string.IsNullOrEmpty(SubCommand); public bool IsSubCommand => !string.IsNullOrEmpty(SubCommand);
...@@ -126,7 +128,8 @@ private RespCommand(RespCommand parent, RespCommand[] subs) ...@@ -126,7 +128,8 @@ private RespCommand(RespCommand parent, RespCommand[] subs)
if (parent.HasSubCommands) throw new InvalidOperationException("Already has sub-commands"); if (parent.HasSubCommands) throw new InvalidOperationException("Already has sub-commands");
if (subs == null || subs.Length == 0) throw new InvalidOperationException("Cannot add empty sub-commands"); if (subs == null || subs.Length == 0) throw new InvalidOperationException("Cannot add empty sub-commands");
Command = parent.Command ?? subs[0].Command; Command = parent.Command;
CommandBytes = parent.CommandBytes;
SubCommand = parent.SubCommand; SubCommand = parent.SubCommand;
Arity = parent.Arity; Arity = parent.Arity;
MaxArgs = parent.MaxArgs; MaxArgs = parent.MaxArgs;
...@@ -180,7 +183,7 @@ internal int NetArity() ...@@ -180,7 +183,7 @@ internal int NetArity()
} }
delegate TypedRedisValue RespOperation(RedisClient client, RedisRequest request); delegate TypedRedisValue RespOperation(RedisClient client, RedisRequest request);
// for extensibility, so that a subclass can get their own client type // for extensibility, so that a subclass can get their own client type
// to be used via ListenForConnections // to be used via ListenForConnections
public virtual RedisClient CreateClient() => new RedisClient(); public virtual RedisClient CreateClient() => new RedisClient();
...@@ -394,14 +397,16 @@ async ValueTask<bool> Awaited(ValueTask wwrite, TypedRedisValue rresponse) ...@@ -394,14 +397,16 @@ async ValueTask<bool> Awaited(ValueTask wwrite, TypedRedisValue rresponse)
public TypedRedisValue Execute(RedisClient client, RedisRequest request) public TypedRedisValue Execute(RedisClient client, RedisRequest request)
{ {
if (string.IsNullOrWhiteSpace(request.Command)) return default; // not a request if (request.Count == 0) return default;// not a request
if (!request.TryGetCommandBytes(0, out var cmdBytes)) return request.CommandNotFound();
if (cmdBytes.Length == 0) return default; // not a request
Interlocked.Increment(ref _totalCommandsProcesed); Interlocked.Increment(ref _totalCommandsProcesed);
try try
{ {
TypedRedisValue result; TypedRedisValue result;
if (_commands.TryGetValue(request.Command, out var cmd)) if (_commands.TryGetValue(cmdBytes, out var cmd))
{ {
request = request.AsCommand(cmd.Command); // fixup casing
if (cmd.HasSubCommands) if (cmd.HasSubCommands)
{ {
cmd = cmd.Resolve(request); cmd = cmd.Resolve(request);
...@@ -426,21 +431,21 @@ public TypedRedisValue Execute(RedisClient client, RedisRequest request) ...@@ -426,21 +431,21 @@ public TypedRedisValue Execute(RedisClient client, RedisRequest request)
if (result.IsNil) if (result.IsNil)
{ {
Log($"missing command: '{request.Command}'"); Log($"missing command: '{request.GetString(0)}'");
return CommandNotFound(request.Command); return request.CommandNotFound();
} }
if (result.Type == ResultType.Error) Interlocked.Increment(ref _totalErrorCount); if (result.Type == ResultType.Error) Interlocked.Increment(ref _totalErrorCount);
return result; return result;
} }
catch (NotSupportedException) catch (NotSupportedException)
{ {
Log($"missing command: '{request.Command}'"); Log($"missing command: '{request.GetString(0)}'");
return CommandNotFound(request.Command); return request.CommandNotFound();
} }
catch (NotImplementedException) catch (NotImplementedException)
{ {
Log($"missing command: '{request.Command}'"); Log($"missing command: '{request.GetString(0)}'");
return CommandNotFound(request.Command); return request.CommandNotFound();
} }
catch (InvalidCastException) catch (InvalidCastException)
{ {
...@@ -460,9 +465,6 @@ internal static string ToLower(RawResult value) ...@@ -460,9 +465,6 @@ internal static string ToLower(RawResult value)
return val.ToLowerInvariant(); return val.ToLowerInvariant();
} }
protected static TypedRedisValue CommandNotFound(string command)
=> TypedRedisValue.Error($"ERR unknown command '{command}'");
[RedisCommand(1, LockFree = true)] [RedisCommand(1, LockFree = true)]
protected virtual TypedRedisValue Command(RedisClient client, RedisRequest request) protected virtual TypedRedisValue Command(RedisClient client, RedisRequest request)
{ {
...@@ -479,8 +481,9 @@ protected virtual TypedRedisValue CommandInfo(RedisClient client, RedisRequest r ...@@ -479,8 +481,9 @@ protected virtual TypedRedisValue CommandInfo(RedisClient client, RedisRequest r
var results = TypedRedisValue.Rent(request.Count - 2, out var span); var results = TypedRedisValue.Rent(request.Count - 2, out var span);
for (int i = 2; i < request.Count; i++) for (int i = 2; i < request.Count; i++)
{ {
span[i - 2] = _commands.TryGetValue(request.GetString(i), out var cmd) span[i - 2] = request.TryGetCommandBytes(i, out var cmdBytes)
? CommandInfo(cmd) : TypedRedisValue.NullArray; &&_commands.TryGetValue(cmdBytes, out var cmdInfo)
? CommandInfo(cmdInfo) : TypedRedisValue.NullArray;
} }
return results; return results;
} }
......
...@@ -13,11 +13,7 @@ ...@@ -13,11 +13,7 @@
<PublicSign Condition=" '$(OS)' != 'Windows_NT' ">true</PublicSign> <PublicSign Condition=" '$(OS)' != 'Windows_NT' ">true</PublicSign>
<LangVersion>latest</LangVersion> <LangVersion>latest</LangVersion>
<NoWarn>$(NoWarn);CS1591</NoWarn> <NoWarn>$(NoWarn);CS1591</NoWarn>
</PropertyGroup> <AllowUnsafeBlocks>true</AllowUnsafeBlocks>
<PropertyGroup Condition="'$(Configuration)|$(TargetFramework)|$(Platform)'=='Debug|netstandard2.0|AnyCPU'">
<AllowUnsafeBlocks>false</AllowUnsafeBlocks>
</PropertyGroup> </PropertyGroup>
<ItemGroup> <ItemGroup>
<ProjectReference Include="..\StackExchange.Redis\StackExchange.Redis.csproj" /> <ProjectReference Include="..\StackExchange.Redis\StackExchange.Redis.csproj" />
......
...@@ -132,8 +132,7 @@ private TypedRedisValue(TypedRedisValue[] oversizedItems, int count) ...@@ -132,8 +132,7 @@ private TypedRedisValue(TypedRedisValue[] oversizedItems, int count)
} }
internal void Recycle(int limit = -1) internal void Recycle(int limit = -1)
{ {
var arr = _value.DirectObject as TypedRedisValue[]; if (_value.DirectObject is TypedRedisValue[] arr)
if (arr != null)
{ {
if (limit < 0) limit = (int)_value.DirectInt64; if (limit < 0) limit = (int)_value.DirectInt64;
for (int i = 0; i < limit; i++) for (int i = 0; i < limit; i++)
......
...@@ -21,6 +21,7 @@ namespace StackExchange.Redis ...@@ -21,6 +21,7 @@ namespace StackExchange.Redis
internal static readonly RawResult Nil = default; internal static readonly RawResult Nil = default;
private readonly ReadOnlySequence<byte> _payload; private readonly ReadOnlySequence<byte> _payload;
internal ReadOnlySequence<byte> DirecyPayload => _payload;
// note: can't use Memory<RawResult> here - struct recursion breaks runtimr // note: can't use Memory<RawResult> here - struct recursion breaks runtimr
private readonly RawResult[] _itemsOversized; private readonly RawResult[] _itemsOversized;
private readonly int _itemsCount; private readonly int _itemsCount;
......
...@@ -598,7 +598,7 @@ public void SlaveOf(EndPoint master, CommandFlags flags = CommandFlags.None) ...@@ -598,7 +598,7 @@ public void SlaveOf(EndPoint master, CommandFlags flags = CommandFlags.None)
// prepare the actual slaveof message (not sent yet) // prepare the actual slaveof message (not sent yet)
var slaveofMsg = CreateSlaveOfMessage(master, flags); var slaveofMsg = CreateSlaveOfMessage(master, flags);
var configuration = this.multiplexer.RawConfig; var configuration = multiplexer.RawConfig;
// attempt to cease having an opinion on the master; will resume that when replication completes // attempt to cease having an opinion on the master; will resume that when replication completes
// (note that this may fail; we aren't depending on it) // (note that this may fail; we aren't depending on it)
......
using System; using System;
using System.IO.Pipelines;
using System.Net; using System.Net;
using System.Threading.Tasks; using System.Threading.Tasks;
using StackExchange.Redis.Server; using StackExchange.Redis.Server;
...@@ -7,14 +8,14 @@ static class Program ...@@ -7,14 +8,14 @@ static class Program
{ {
static async Task Main() static async Task Main()
{ {
//using (var pool = new DedicatedThreadPoolPipeScheduler(minWorkers: 10, maxWorkers: 10, using (var pool = new Pipelines.Sockets.Unofficial.DedicatedThreadPoolPipeScheduler(minWorkers: 10, maxWorkers: 50,
// priority: System.Threading.ThreadPriority.Highest)) priority: System.Threading.ThreadPriority.Highest))
using (var resp = new MemoryCacheRedisServer(Console.Out)) using (var resp = new MemoryCacheRedisServer(Console.Out))
using (var socket = new RespSocketServer(resp)) using (var socket = new RespSocketServer(resp))
{ {
//var options = new PipeOptions(readerScheduler: pool, writerScheduler: pool, useSynchronizationContext: false); var options = new PipeOptions(readerScheduler: pool, writerScheduler: pool, useSynchronizationContext: false);
socket.Listen(new IPEndPoint(IPAddress.Loopback, 6378) socket.Listen(new IPEndPoint(IPAddress.Loopback, 6378)
//, sendOptions: options, receiveOptions: options , sendOptions: options, receiveOptions: options
); );
await resp.Shutdown; await resp.Shutdown;
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment