Unverified Commit fe830a47 authored by Marc Gravell's avatar Marc Gravell Committed by GitHub

Implement a redis-server, to help testing (#894)

* v0.0 of test server

* server very much doesn't work, but: pieces are in place

* it's alive!

* implement MemoryCache server

* be a little pickier about which faults we actually report

* implement **very badly** pre-RESP protocol for server; this is just to test PING_INLINE etc, and it *finds a catastrophic case*

* make shutdown much more graceful; Databases should be immutable; enable ConnectionExecute vs ServerExecute concept; add awaitable server shutdown task

* make MemoryCache usage standalone, and enable flush; implement cusstom box/unbox operations on RedisValue

* implement basic "info"

* limit scope of GetCurrentProcess

* set slaveof in config to make clients happier

* rename server types; implement MEMORY PURGE and TIME; make types explicit throughout in RedisServer

* implement KEYS; prefer NotSupportedException to NotImplementedException, but recognize both

* implement UNLINK; better handling of nil requests

* implement basic set ops

* implement STRLEN; handle WRONGTYPE

* convention / reflection based command registration

* overhaul how arity works so we can implement COMMAND; support null arrays

* make sure we can parse null arras is RedisResult

* set server socket options

* fix error handling incomplete lines in the "inline" protocol

* move ParseInlineProtocol out, but: still not implemented

* implement CLIENT REPLY and add the "inline" protocol

* need to support either quote tokenizer

* add readme to the server code

* accessibility on sample code

* implement the last of the commands needed for redis-benchmark

* fix naming
parent 42f95161
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Runtime.Caching;
using System.Runtime.CompilerServices;
namespace StackExchange.Redis.Server
{
public class MemoryCacheRedisServer : RedisServer
{
public MemoryCacheRedisServer(TextWriter output = null) : base(1, output)
=> CreateNewCache();
private MemoryCache _cache;
private void CreateNewCache()
{
var old = _cache;
_cache = new MemoryCache(GetType().Name);
if (old != null) old.Dispose();
}
protected override void Dispose(bool disposing)
{
if (disposing) _cache.Dispose();
base.Dispose(disposing);
}
protected override long Dbsize(int database) => _cache.GetCount();
protected override RedisValue Get(int database, RedisKey key)
=> RedisValue.Unbox(_cache[key]);
protected override void Set(int database, RedisKey key, RedisValue value)
=> _cache[key] = value.Box();
protected override bool Del(int database, RedisKey key)
=> _cache.Remove(key) != null;
protected override void Flushdb(int database)
=> CreateNewCache();
protected override bool Exists(int database, RedisKey key)
=> _cache.Contains(key);
protected override IEnumerable<RedisKey> Keys(int database, RedisKey pattern)
{
string s = pattern;
foreach (var pair in _cache)
{
if (IsMatch(pattern, pair.Key)) yield return pair.Key;
}
}
protected override bool Sadd(int database, RedisKey key, RedisValue value)
=> GetSet(key, true).Add(value);
protected override bool Sismember(int database, RedisKey key, RedisValue value)
=> GetSet(key, false)?.Contains(value) ?? false;
protected override bool Srem(int database, RedisKey key, RedisValue value)
{
var set = GetSet(key, false);
if (set != null && set.Remove(value))
{
if (set.Count == 0) _cache.Remove(key);
return true;
}
return false;
}
protected override long Scard(int database, RedisKey key)
=> GetSet(key, false)?.Count ?? 0;
HashSet<RedisValue> GetSet(RedisKey key, bool create)
{
var set = (HashSet<RedisValue>)_cache[key];
if (set == null && create)
{
set = new HashSet<RedisValue>();
_cache[key] = set;
}
return set;
}
protected override RedisValue Spop(int database, RedisKey key)
{
var set = GetSet(key, false);
if (set == null) return RedisValue.Null;
var result = set.First();
set.Remove(result);
if (set.Count == 0) _cache.Remove(key);
return result;
}
protected override long Lpush(int database, RedisKey key, RedisValue value)
{
var stack = GetStack(key, true);
stack.Push(value);
return stack.Count;
}
protected override RedisValue Lpop(int database, RedisKey key)
{
var stack = GetStack(key, false);
if (stack == null) return RedisValue.Null;
var val = stack.Pop();
if(stack.Count == 0) _cache.Remove(key);
return val;
}
protected override long Llen(int database, RedisKey key)
=> GetStack(key, false)?.Count ?? 0;
[MethodImpl(MethodImplOptions.NoInlining)]
static void ThrowArgumentOutOfRangeException() => throw new ArgumentOutOfRangeException();
protected override void LRange(int database, RedisKey key, long start, RedisValue[] arr)
{
var stack = GetStack(key, false);
using (var iter = stack.GetEnumerator())
{
// skip
while (start-- > 0) if (!iter.MoveNext()) ThrowArgumentOutOfRangeException();
// take
for (int i = 0; i < arr.Length; i++)
{
if (!iter.MoveNext()) ThrowArgumentOutOfRangeException();
arr[i] = iter.Current;
}
}
}
Stack<RedisValue> GetStack(RedisKey key, bool create)
{
var stack = (Stack<RedisValue>)_cache[key];
if (stack == null && create)
{
stack = new Stack<RedisValue>();
_cache[key] = stack;
}
return stack;
}
}
}
using System;
using System.Collections.Generic;
using System.IO.Pipelines;
namespace StackExchange.Redis.Server
{
public sealed class RedisClient : IDisposable
{
internal int SkipReplies { get; set; }
internal bool ShouldSkipResponse()
{
if (SkipReplies > 0)
{
SkipReplies--;
return true;
}
return false;
}
private HashSet<RedisChannel> _subscripions;
public int SubscriptionCount => _subscripions?.Count ?? 0;
internal int Subscribe(RedisChannel channel)
{
if (_subscripions == null) _subscripions = new HashSet<RedisChannel>();
_subscripions.Add(channel);
return _subscripions.Count;
}
internal int Unsubscribe(RedisChannel channel)
{
if (_subscripions == null) return 0;
_subscripions.Remove(channel);
return _subscripions.Count;
}
public int Database { get; set; }
public string Name { get; set; }
internal IDuplexPipe LinkedPipe { get; set; }
public bool Closed { get; internal set; }
public void Dispose()
{
Closed = true;
var pipe = LinkedPipe;
LinkedPipe = null;
if (pipe != null)
{
try { pipe.Input.CancelPendingRead(); } catch { }
try { pipe.Input.Complete(); } catch { }
try { pipe.Output.CancelPendingFlush(); } catch { }
try { pipe.Output.Complete(); } catch { }
if (pipe is IDisposable d) try { d.Dispose(); } catch { }
}
}
}
}
using System;
namespace StackExchange.Redis.Server
{
public readonly ref struct RedisRequest
{ // why ref? don't *really* need it, but: these things are "in flight"
// based on an open RawResult (which is just the detokenized ReadOnlySequence<byte>)
// so: using "ref" makes it clear that you can't expect to store these and have
// them keep working
private readonly RawResult _inner;
public int Count { get; }
public string Command { get; }
public override string ToString() => Command;
public override bool Equals(object obj) => throw new NotSupportedException();
public RedisResult WrongArgCount() => RedisResult.Create($"ERR wrong number of arguments for '{Command}' command", ResultType.Error);
public RedisResult UnknownSubcommandOrArgumentCount() => RedisResult.Create($"ERR Unknown subcommand or wrong number of arguments for '{Command}'.", ResultType.Error);
public string GetString(int index)
=> _inner[index].GetString();
internal RedisResult GetResult(int index)
=> RedisResult.Create(_inner[index].AsRedisValue());
public bool IsString(int index, string value) // TODO: optimize
=> string.Equals(value, _inner[index].GetString(), StringComparison.OrdinalIgnoreCase);
public override int GetHashCode() => throw new NotSupportedException();
internal RedisRequest(RawResult result)
: this(result, result.ItemsCount, result[0].GetString()) { }
private RedisRequest(RawResult inner, int count, string command)
{
_inner = inner;
Count = count;
Command = command;
}
internal RedisRequest AsCommand(string command)
=> new RedisRequest(_inner, Count, command);
public void Recycle() => _inner.Recycle();
public RedisValue GetValue(int index)
=> _inner[index].AsRedisValue();
public int GetInt32(int index)
=> (int)_inner[index].AsRedisValue();
public long GetInt64(int index) => (long)_inner[index].AsRedisValue();
public RedisKey GetKey(int index) => _inner[index].AsRedisKey();
public RedisChannel GetChannel(int index, RedisChannel.PatternMode mode)
=> _inner[index].AsRedisChannel(null, mode);
}
}
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using System.Text;
namespace StackExchange.Redis.Server
{
public abstract class RedisServer : RespServer
{
public static bool IsMatch(string pattern, string key)
{
// non-trivial wildcards not implemented yet!
return pattern == "*" || string.Equals(pattern, key, StringComparison.OrdinalIgnoreCase);
}
protected RedisServer(int databases = 16, TextWriter output = null) : base(output)
{
if (databases < 1) throw new ArgumentOutOfRangeException(nameof(databases));
var config = ServerConfiguration;
config["timeout"] = "0";
config["slave-read-only"] = "yes";
config["databases"] = databases.ToString();
config["slaveof"] = "";
}
public int Databases { get; }
[RedisCommand(-3)]
protected virtual RedisResult Sadd(RedisClient client, RedisRequest request)
{
int added = 0;
var key = request.GetKey(1);
for (int i = 2; i < request.Count; i++)
{
if (Sadd(client.Database, key, request.GetValue(i)))
added++;
}
return RedisResult.Create(added, ResultType.Integer);
}
protected virtual bool Sadd(int database, RedisKey key, RedisValue value) => throw new NotSupportedException();
[RedisCommand(-3)]
protected virtual RedisResult Srem(RedisClient client, RedisRequest request)
{
int removed = 0;
var key = request.GetKey(1);
for (int i = 2; i < request.Count; i++)
{
if (Srem(client.Database, key, request.GetValue(i)))
removed++;
}
return RedisResult.Create(removed, ResultType.Integer);
}
protected virtual bool Srem(int database, RedisKey key, RedisValue value) => throw new NotSupportedException();
[RedisCommand(2)]
protected virtual RedisResult Spop(RedisClient client, RedisRequest request)
=> RedisResult.Create(Spop(client.Database, request.GetKey(1)), ResultType.BulkString);
protected virtual RedisValue Spop(int database, RedisKey key) => throw new NotSupportedException();
[RedisCommand(2)]
protected virtual RedisResult Scard(RedisClient client, RedisRequest request)
=> RedisResult.Create(Scard(client.Database, request.GetKey(1)), ResultType.Integer);
protected virtual long Scard(int database, RedisKey key) => throw new NotSupportedException();
[RedisCommand(3)]
protected virtual RedisResult Sismember(RedisClient client, RedisRequest request)
=> Sismember(client.Database, request.GetKey(1), request.GetValue(2)) ? RedisResult.One : RedisResult.Zero;
protected virtual bool Sismember(int database, RedisKey key, RedisValue value) => throw new NotSupportedException();
[RedisCommand(3, "client", "setname", LockFree = true)]
protected virtual RedisResult ClientSetname(RedisClient client, RedisRequest request)
{
client.Name = request.GetString(2);
return RedisResult.OK;
}
[RedisCommand(2, "client", "getname", LockFree = true)]
protected virtual RedisResult ClientGetname(RedisClient client, RedisRequest request)
=> RedisResult.Create(client.Name, ResultType.BulkString);
[RedisCommand(3, "client", "reply", LockFree = true)]
protected virtual RedisResult ClientReply(RedisClient client, RedisRequest request)
{
if (request.IsString(2, "on")) client.SkipReplies = -1; // reply to nothing
else if (request.IsString(2, "off")) client.SkipReplies = 0; // reply to everything
else if (request.IsString(2, "skip")) client.SkipReplies = 2; // this one, and the next one
else return RedisResult.Create("ERR syntax error", ResultType.Error);
return RedisResult.OK;
}
[RedisCommand(-1)]
protected virtual RedisResult Cluster(RedisClient client, RedisRequest request)
=> CommandNotFound(request.Command);
[RedisCommand(-3)]
protected virtual RedisResult Lpush(RedisClient client, RedisRequest request)
{
var key = request.GetKey(1);
long length = -1;
for (int i = 2; i < request.Count; i++)
{
length = Lpush(client.Database, key, request.GetValue(i));
}
return RedisResult.Create(length, ResultType.Integer);
}
[RedisCommand(-3)]
protected virtual RedisResult Rpush(RedisClient client, RedisRequest request)
{
var key = request.GetKey(1);
long length = -1;
for (int i = 2; i < request.Count; i++)
{
length = Rpush(client.Database, key, request.GetValue(i));
}
return RedisResult.Create(length, ResultType.Integer);
}
[RedisCommand(2)]
protected virtual RedisResult Lpop(RedisClient client, RedisRequest request)
=> RedisResult.Create(Lpop(client.Database, request.GetKey(1)), ResultType.BulkString);
[RedisCommand(2)]
protected virtual RedisResult Rpop(RedisClient client, RedisRequest request)
=> RedisResult.Create(Rpop(client.Database, request.GetKey(1)), ResultType.BulkString);
[RedisCommand(2)]
protected virtual RedisResult Llen(RedisClient client, RedisRequest request)
=> RedisResult.Create(Llen(client.Database, request.GetKey(1)), ResultType.Integer);
protected virtual long Lpush(int database, RedisKey key, RedisValue value) => throw new NotSupportedException();
protected virtual long Rpush(int database, RedisKey key, RedisValue value) => throw new NotSupportedException();
protected virtual long Llen(int database, RedisKey key) => throw new NotSupportedException();
protected virtual RedisValue Rpop(int database, RedisKey key) => throw new NotSupportedException();
protected virtual RedisValue Lpop(int database, RedisKey key) => throw new NotSupportedException();
[RedisCommand(4)]
protected virtual RedisResult LRange(RedisClient client, RedisRequest request)
{
var key = request.GetKey(1);
long start = request.GetInt64(2), stop = request.GetInt64(3);
var len = Llen(client.Database, key);
if (len == 0) return RedisResult.EmptyArray;
if (start < 0) start = len + start;
if (stop < 0) stop = len + stop;
if (stop < 0 || start >= len || stop < start) return RedisResult.EmptyArray;
if (start < 0) start = 0;
else if (start >= len) start = len - 1;
if (stop < 0) stop = 0;
else if (stop >= len) stop = len - 1;
var arr = new RedisValue[(stop - start) + 1];
LRange(client.Database, key, start, arr);
return RedisResult.Create(arr);
}
protected virtual void LRange(int database, RedisKey key, long start, RedisValue[] arr) => throw new NotSupportedException();
protected virtual void OnUpdateServerConfiguration() { }
protected RedisConfig ServerConfiguration { get; } = RedisConfig.Create();
protected struct RedisConfig
{
internal static RedisConfig Create() => new RedisConfig(
new Dictionary<string, string>(StringComparer.OrdinalIgnoreCase));
internal Dictionary<string, string> Wrapped { get; }
public int Count => Wrapped.Count;
private RedisConfig(Dictionary<string, string> inner) => Wrapped = inner;
public string this[string key]
{
get => Wrapped.TryGetValue(key, out var val) ? val : null;
set
{
if (Wrapped.ContainsKey(key)) Wrapped[key] = value; // no need to fix case
else Wrapped[key.ToLowerInvariant()] = value;
}
}
internal int CountMatch(string pattern)
{
int count = 0;
foreach (var pair in Wrapped)
{
if (IsMatch(pattern, pair.Key)) count++;
}
return count;
}
}
[RedisCommand(3, "config", "get", LockFree = true)]
protected virtual RedisResult Config(RedisClient client, RedisRequest request)
{
var pattern = request.GetString(2);
OnUpdateServerConfiguration();
var config = ServerConfiguration;
var matches = config.CountMatch(pattern);
if (matches == 0) return RedisResult.Create(Array.Empty<RedisResult>());
var arr = new RedisResult[2 * matches];
int index = 0;
foreach (var pair in config.Wrapped)
{
if (IsMatch(pattern, pair.Key))
{
arr[index++] = RedisResult.Create(pair.Key, ResultType.BulkString);
arr[index++] = RedisResult.Create(pair.Value, ResultType.BulkString);
}
}
if (index != arr.Length) throw new InvalidOperationException("Configuration CountMatch fail");
return RedisResult.Create(arr);
}
[RedisCommand(2, LockFree = true)]
protected virtual RedisResult Echo(RedisClient client, RedisRequest request)
=> request.GetResult(1);
[RedisCommand(2)]
protected virtual RedisResult Exists(RedisClient client, RedisRequest request)
{
int count = 0;
var db = client.Database;
for (int i = 1; i < request.Count; i++)
{
if (Exists(db, request.GetKey(i)))
count++;
}
return RedisResult.Create(count, ResultType.Integer);
}
protected virtual bool Exists(int database, RedisKey key)
{
try
{
return !Get(database, key).IsNull;
}
catch (InvalidCastException) { return true; } // to be an invalid cast, it must exist
}
[RedisCommand(2)]
protected virtual RedisResult Get(RedisClient client, RedisRequest request)
=> RedisResult.Create(Get(client.Database, request.GetKey(1)), ResultType.BulkString);
protected virtual RedisValue Get(int database, RedisKey key) => throw new NotSupportedException();
[RedisCommand(3)]
protected virtual RedisResult Set(RedisClient client, RedisRequest request)
{
Set(client.Database, request.GetKey(1), request.GetValue(2));
return RedisResult.OK;
}
protected virtual void Set(int database, RedisKey key, RedisValue value) => throw new NotSupportedException();
[RedisCommand(1)]
protected new virtual RedisResult Shutdown(RedisClient client, RedisRequest request)
{
DoShutdown();
return RedisResult.OK;
}
[RedisCommand(2)]
protected virtual RedisResult Strlen(RedisClient client, RedisRequest request)
=> RedisResult.Create(Strlen(client.Database, request.GetKey(1)), ResultType.Integer);
protected virtual long Strlen(int database, RedisKey key) => Get(database, key).Length();
[RedisCommand(-2)]
protected virtual RedisResult Del(RedisClient client, RedisRequest request)
{
int count = 0;
for (int i = 1; i < request.Count; i++)
{
if (Del(client.Database, request.GetKey(i)))
count++;
}
return RedisResult.Create(count, ResultType.Integer);
}
protected virtual bool Del(int database, RedisKey key) => throw new NotSupportedException();
[RedisCommand(1)]
protected virtual RedisResult Dbsize(RedisClient client, RedisRequest request)
=> RedisResult.Create(Dbsize(client.Database), ResultType.Integer);
protected virtual long Dbsize(int database) => throw new NotSupportedException();
[RedisCommand(1)]
protected virtual RedisResult Flushall(RedisClient client, RedisRequest request)
{
var count = Databases;
for (int i = 0; i < count; i++)
{
Flushdb(i);
}
return RedisResult.OK;
}
[RedisCommand(1)]
protected virtual RedisResult Flushdb(RedisClient client, RedisRequest request)
{
Flushdb(client.Database);
return RedisResult.OK;
}
protected virtual void Flushdb(int database) => throw new NotSupportedException();
[RedisCommand(-1, LockFree = true, MaxArgs = 2)]
protected virtual RedisResult Info(RedisClient client, RedisRequest request)
{
var info = Info(request.Count == 1 ? null : request.GetString(1));
return RedisResult.Create(info, ResultType.BulkString);
}
protected virtual string Info(string selected)
{
var sb = new StringBuilder();
bool IsMatch(string section) => string.IsNullOrWhiteSpace(selected)
|| string.Equals(section, selected, StringComparison.OrdinalIgnoreCase);
if (IsMatch("Server")) Info(sb, "Server");
if (IsMatch("Clients")) Info(sb, "Clients");
if (IsMatch("Memory")) Info(sb, "Memory");
if (IsMatch("Persistence")) Info(sb, "Persistence");
if (IsMatch("Stats")) Info(sb, "Stats");
if (IsMatch("Replication")) Info(sb, "Replication");
if (IsMatch("Keyspace")) Info(sb, "Keyspace");
return sb.ToString();
}
[RedisCommand(2)]
protected virtual RedisResult Keys(RedisClient client, RedisRequest request)
{
List<RedisResult> found = null;
foreach (var key in Keys(client.Database, request.GetKey(1)))
{
if (found == null) found = new List<RedisResult>();
found.Add(RedisResult.Create(key));
}
return RedisResult.Create(
found == null ? Array.Empty<RedisResult>() : found.ToArray());
}
protected virtual IEnumerable<RedisKey> Keys(int database, RedisKey pattern) => throw new NotSupportedException();
protected virtual void Info(StringBuilder sb, string section)
{
StringBuilder AddHeader()
{
if (sb.Length != 0) sb.AppendLine();
return sb.Append("# ").AppendLine(section);
}
switch (section)
{
case "Server":
AddHeader().AppendLine("redis_version:1.0")
.AppendLine("redis_mode:standalone")
.Append("os:").Append(Environment.OSVersion).AppendLine()
.Append("arch_bits:x").Append(IntPtr.Size * 8).AppendLine();
using (var process = Process.GetCurrentProcess())
{
sb.Append("process:").Append(process.Id).AppendLine();
}
var port = TcpPort();
if (port >= 0) sb.Append("tcp_port:").Append(port).AppendLine();
break;
case "Clients":
AddHeader().Append("connected_clients:").Append(ClientCount).AppendLine();
break;
case "Memory":
break;
case "Persistence":
AddHeader().AppendLine("loading:0");
break;
case "Stats":
AddHeader().Append("total_connections_received:").Append(TotalClientCount).AppendLine()
.Append("total_commands_processed:").Append(CommandsProcesed).AppendLine();
break;
case "Replication":
AddHeader().AppendLine("role:master");
break;
case "Keyspace":
break;
}
}
[RedisCommand(2, "memory", "purge")]
protected virtual RedisResult MemoryPurge(RedisClient client, RedisRequest request)
{
GC.Collect(GC.MaxGeneration, GCCollectionMode.Forced);
return RedisResult.OK;
}
[RedisCommand(-2)]
protected virtual RedisResult Mget(RedisClient client, RedisRequest request)
{
int argCount = request.Count;
var arr = new RedisResult[argCount - 1];
var db = client.Database;
for (int i = 1; i < argCount; i++)
{
arr[i - 1] = RedisResult.Create(Get(db, request.GetKey(i)), ResultType.BulkString);
}
return RedisResult.Create(arr);
}
[RedisCommand(-3)]
protected virtual RedisResult Mset(RedisClient client, RedisRequest request)
{
int argCount = request.Count;
var db = client.Database;
for (int i = 1; i < argCount;)
{
Set(db, request.GetKey(i++), request.GetValue(i++));
}
return RedisResult.OK;
}
[RedisCommand(-1, LockFree = true, MaxArgs = 2)]
protected virtual RedisResult Ping(RedisClient client, RedisRequest request)
=> RedisResult.Create(request.Count == 1 ? "PONG" : request.GetString(1), ResultType.SimpleString);
[RedisCommand(1, LockFree = true)]
protected virtual RedisResult Quit(RedisClient client, RedisRequest request)
{
RemoveClient(client);
return RedisResult.OK;
}
[RedisCommand(1, LockFree = true)]
protected virtual RedisResult Role(RedisClient client, RedisRequest request)
{
return RedisResult.Create(new[]
{
RedisResult.Create("master", ResultType.BulkString),
RedisResult.Create(0, ResultType.Integer),
RedisResult.Create(Array.Empty<RedisResult>())
});
}
[RedisCommand(2, LockFree = true)]
protected virtual RedisResult Select(RedisClient client, RedisRequest request)
{
var raw = request.GetValue(1);
if (!raw.IsInteger) return RedisResult.Create("ERR invalid DB index", ResultType.Error);
int db = (int)raw;
if (db < 0 || db >= Databases) return RedisResult.Create("ERR DB index is out of range", ResultType.Error);
client.Database = db;
return RedisResult.OK;
}
[RedisCommand(-2)]
protected virtual RedisResult Subscribe(RedisClient client, RedisRequest request)
=> SubscribeImpl(client, request);
[RedisCommand(-2)]
protected virtual RedisResult Unsubscribe(RedisClient client, RedisRequest request)
=> SubscribeImpl(client, request);
private RedisResult SubscribeImpl(RedisClient client, RedisRequest request)
{
var reply = new RedisResult[3 * (request.Count - 1)];
int index = 0;
var mode = request.Command[0] == 'p' ? RedisChannel.PatternMode.Pattern : RedisChannel.PatternMode.Literal;
for (int i = 1; i < request.Count; i++)
{
var channel = request.GetChannel(i, mode);
int count;
switch (request.Command)
{
case "subscribe": count = client.Subscribe(channel); break;
case "unsubscribe": count = client.Unsubscribe(channel); break;
default: return null;
}
reply[index++] = RedisResult.Create(request.Command, ResultType.BulkString);
reply[index++] = RedisResult.Create(channel);
reply[index++] = RedisResult.Create(count, ResultType.Integer);
}
return RedisResult.Create(reply);
}
static readonly DateTime UnixEpoch = new DateTime(1970, 1, 1, 0, 0, 0, DateTimeKind.Utc);
[RedisCommand(1, LockFree = true)]
protected virtual RedisResult Time(RedisClient client, RedisRequest request)
{
var delta = Time() - UnixEpoch;
var ticks = delta.Ticks;
var seconds = ticks / TimeSpan.TicksPerSecond;
var micros = (ticks % TimeSpan.TicksPerSecond) / (TimeSpan.TicksPerMillisecond / 1000);
return RedisResult.Create(new[] {
RedisResult.Create(seconds, ResultType.BulkString),
RedisResult.Create(micros, ResultType.BulkString),
});
}
protected virtual DateTime Time() => DateTime.UtcNow;
[RedisCommand(-2)]
protected virtual RedisResult Unlink(RedisClient client, RedisRequest request)
=> Del(client, request);
[RedisCommand(2)]
protected virtual RedisResult Incr(RedisClient client, RedisRequest request)
=> RedisResult.Create(IncrBy(client.Database, request.GetKey(1), 1), ResultType.Integer);
[RedisCommand(2)]
protected virtual RedisResult Decr(RedisClient client, RedisRequest request)
=> RedisResult.Create(IncrBy(client.Database, request.GetKey(1), -1), ResultType.Integer);
[RedisCommand(3)]
protected virtual RedisResult IncrBy(RedisClient client, RedisRequest request)
=> RedisResult.Create(IncrBy(client.Database, request.GetKey(1), request.GetInt64(2)), ResultType.Integer);
protected virtual long IncrBy(int database, RedisKey key, long delta)
{
var value = ((long)Get(database, key)) + delta;
Set(database, key, value);
return value;
}
}
}

using System;
using System.Buffers;
using System.Collections.Generic;
using System.IO;
using System.IO.Pipelines;
using System.Linq;
using System.Net;
using System.Net.Sockets;
using System.Reflection;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using Pipelines.Sockets.Unofficial;
namespace StackExchange.Redis.Server
{
public abstract partial class RespServer : IDisposable
{
private readonly List<RedisClient> _clients = new List<RedisClient>();
private readonly TextWriter _output;
private Socket _listener;
public RespServer(TextWriter output = null)
{
_output = output;
_commands = BuildCommands(this);
}
private static Dictionary<string, RespCommand> BuildCommands(RespServer server)
{
RedisCommandAttribute CheckSignatureAndGetAttribute(MethodInfo method)
{
if (method.ReturnType != typeof(RedisResult)) return null;
var p = method.GetParameters();
if (p.Length != 2 || p[0].ParameterType != typeof(RedisClient) || p[1].ParameterType != typeof(RedisRequest))
return null;
return (RedisCommandAttribute)Attribute.GetCustomAttribute(method, typeof(RedisCommandAttribute));
}
var grouped = from method in server.GetType().GetMethods(BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic)
let attrib = CheckSignatureAndGetAttribute(method)
where attrib != null
select new RespCommand(attrib, method, server) into cmd
group cmd by cmd.Command;
var result = new Dictionary<string, RespCommand>(StringComparer.OrdinalIgnoreCase);
foreach (var grp in grouped)
{
RespCommand parent;
if (grp.Any(x => x.IsSubCommand))
{
var subs = grp.Where(x => x.IsSubCommand).ToArray();
parent = grp.SingleOrDefault(x => !x.IsSubCommand).WithSubCommands(subs);
}
else
{
parent = grp.Single();
}
result.Add(grp.Key, parent);
}
return result;
}
[AttributeUsage(AttributeTargets.Method, AllowMultiple = false, Inherited = true)]
protected sealed class RedisCommandAttribute : Attribute
{
public RedisCommandAttribute(int arity,
string command = null, string subcommand = null)
{
Command = command;
SubCommand = subcommand;
Arity = arity;
MaxArgs = Arity > 0 ? Arity : int.MaxValue;
}
public int MaxArgs { get; set; }
public string Command { get; }
public string SubCommand { get; }
public int Arity { get; }
public bool LockFree { get; set; }
}
private readonly Dictionary<string, RespCommand> _commands;
readonly struct RespCommand
{
public RespCommand(RedisCommandAttribute attrib, MethodInfo method, RespServer server)
{
_operation = (RespOperation)Delegate.CreateDelegate(typeof(RespOperation), server, method);
Command = (string.IsNullOrWhiteSpace(attrib.Command) ? method.Name : attrib.Command).Trim().ToLowerInvariant();
SubCommand = attrib.SubCommand?.Trim()?.ToLowerInvariant();
Arity = attrib.Arity;
MaxArgs = attrib.MaxArgs;
LockFree = attrib.LockFree;
_subcommands = null;
}
public string Command { get; }
public string SubCommand { get; }
public bool IsSubCommand => !string.IsNullOrEmpty(SubCommand);
public int Arity { get; }
public int MaxArgs { get; }
public bool LockFree { get; }
readonly RespOperation _operation;
private readonly RespCommand[] _subcommands;
public bool HasSubCommands => _subcommands != null;
internal RespCommand WithSubCommands(RespCommand[] subs)
=> new RespCommand(this, subs);
private RespCommand(RespCommand parent, RespCommand[] subs)
{
if (parent.IsSubCommand) throw new InvalidOperationException("Cannot have nested sub-commands");
if (parent.HasSubCommands) throw new InvalidOperationException("Already has sub-commands");
if (subs == null || subs.Length == 0) throw new InvalidOperationException("Cannot add empty sub-commands");
Command = parent.Command ?? subs[0].Command;
SubCommand = parent.SubCommand;
Arity = parent.Arity;
MaxArgs = parent.MaxArgs;
LockFree = parent.LockFree;
_operation = parent._operation;
_subcommands = subs;
}
public bool IsUnknown => _operation == null;
public RespCommand Resolve(in RedisRequest request)
{
if (request.Count >= 2)
{
var subs = _subcommands;
if (subs != null)
{
var subcommand = request.GetString(1);
for (int i = 0; i < subs.Length; i++)
{
if (string.Equals(subcommand, subs[i].SubCommand, StringComparison.OrdinalIgnoreCase))
return subs[i];
}
}
}
return this;
}
public RedisResult Execute(RedisClient client, RedisRequest request)
{
var args = request.Count;
if (!CheckArity(request.Count)) return IsSubCommand
? request.UnknownSubcommandOrArgumentCount()
: request.WrongArgCount();
return _operation(client, request);
}
private bool CheckArity(int count)
=> count <= MaxArgs && (Arity <= 0 ? count >= -Arity : count == Arity);
internal int NetArity()
{
if (!HasSubCommands) return Arity;
var minMagnitude = _subcommands.Min(x => Math.Abs(x.Arity));
bool varadic = _subcommands.Any(x => x.Arity <= 0);
if (!IsUnknown)
{
minMagnitude = Math.Min(minMagnitude, Math.Abs(Arity));
if (Arity <= 0) varadic = true;
}
return varadic ? -minMagnitude : minMagnitude;
}
}
delegate RedisResult RespOperation(RedisClient client, RedisRequest request);
protected int TcpPort()
{
var ep = _listener?.LocalEndPoint;
if (ep is IPEndPoint ip) return ip.Port;
if (ep is DnsEndPoint dns) return dns.Port;
return -1;
}
private Action<object> _runClientCallback;
private Action<object> RunClientCallback => _runClientCallback ??
(_runClientCallback = state => RunClient((RedisClient)state));
public void Listen(
EndPoint endpoint,
AddressFamily addressFamily = AddressFamily.InterNetwork,
SocketType socketType = SocketType.Stream,
ProtocolType protocolType = ProtocolType.Tcp,
PipeOptions sendOptions = null, PipeOptions receiveOptions = null)
{
Socket listener = new Socket(addressFamily, socketType, protocolType);
listener.Bind(endpoint);
listener.Listen(20);
_listener = listener;
StartOnScheduler(receiveOptions?.ReaderScheduler, _ => ListenForConnections(
sendOptions ?? PipeOptions.Default, receiveOptions ?? PipeOptions.Default), null);
Log("Server is listening on " + Format.ToString(endpoint));
}
private static void StartOnScheduler(PipeScheduler scheduler, Action<object> callback, object state)
{
if (scheduler == PipeScheduler.Inline) scheduler = null;
(scheduler ?? PipeScheduler.ThreadPool).Schedule(callback, state);
}
// for extensibility, so that a subclass can get their own client type
// to be used via ListenForConnections
protected virtual RedisClient CreateClient() => new RedisClient();
public int ClientCount
{
get { lock (_clients) { return _clients.Count; } }
}
public int TotalClientCount { get; private set; }
public void AddClient(RedisClient client)
{
if (client == null) throw new ArgumentNullException(nameof(client));
lock (_clients)
{
ThrowIfShutdown();
_clients.Add(client);
TotalClientCount++;
}
}
public bool RemoveClient(RedisClient client)
{
if (client == null) throw new ArgumentNullException(nameof(client));
lock (_clients)
{
client.Closed = true;
return _clients.Remove(client);
}
}
private async void ListenForConnections(PipeOptions sendOptions, PipeOptions receiveOptions)
{
try
{
while (true)
{
var client = await _listener.AcceptAsync();
SocketConnection.SetRecommendedServerOptions(client);
var pipe = SocketConnection.Create(client, sendOptions, receiveOptions);
var c = CreateClient();
c.LinkedPipe = pipe;
AddClient(c);
StartOnScheduler(receiveOptions.ReaderScheduler, RunClientCallback, c);
}
}
catch (NullReferenceException) { }
catch (ObjectDisposedException) { }
catch (Exception ex)
{
if(!_isShutdown) Log("Listener faulted: " + ex.Message);
}
}
private readonly TaskCompletionSource<int> _shutdown = new TaskCompletionSource<int>();
private bool _isShutdown;
protected void ThrowIfShutdown()
{
if (_isShutdown) throw new InvalidOperationException("The server is shutting down");
}
protected void DoShutdown(PipeScheduler scheduler = null)
{
if (_isShutdown) return;
Log("Server shutting down...");
_isShutdown = true;
lock (_clients)
{
foreach (var client in _clients) client.Dispose();
_clients.Clear();
}
StartOnScheduler(scheduler,
state => ((TaskCompletionSource<int>)state).TrySetResult(0), _shutdown);
}
public Task Shutdown => _shutdown.Task;
public void Dispose() => Dispose(true);
protected virtual void Dispose(bool disposing)
{
DoShutdown();
var socket = _listener;
if (socket != null)
{
try { socket.Dispose(); } catch { }
}
}
async void RunClient(RedisClient client)
{
ThrowIfShutdown();
var input = client?.LinkedPipe?.Input;
var output = client?.LinkedPipe?.Output;
if (input == null || output == null) return; // nope
Exception fault = null;
try
{
while (!client.Closed)
{
var readResult = await input.ReadAsync();
var buffer = readResult.Buffer;
bool makingProgress = false;
while (!client.Closed && TryProcessRequest(ref buffer, client, output))
{
makingProgress = true;
await output.FlushAsync();
}
input.AdvanceTo(buffer.Start, buffer.End);
if (!makingProgress && readResult.IsCompleted)
{
break;
}
}
}
catch (ConnectionResetException) { }
catch (ObjectDisposedException) { }
catch (Exception ex) { fault = ex; }
finally
{
try { input.Complete(fault); } catch { }
try { output.Complete(fault); } catch { }
if (fault != null && !_isShutdown)
{
Log("Connection faulted (" + fault.GetType().Name + "): " + fault.Message);
}
}
}
private void Log(string message)
{
var output = _output;
if (output != null)
{
lock (output)
{
output.WriteLine(message);
}
}
}
private Encoder _serverEncoder = Encoding.UTF8.GetEncoder();
static Encoder s_sharedEncoder; // swapped in/out to avoid alloc on the public WriteResponse API
public static void WriteResponse(RedisClient client, PipeWriter output, RedisResult response)
{
var enc = Interlocked.Exchange(ref s_sharedEncoder, null) ?? Encoding.UTF8.GetEncoder();
WriteResponse(client, output, response, enc);
Interlocked.Exchange(ref s_sharedEncoder, enc);
}
internal static void WriteResponse(RedisClient client, PipeWriter output, RedisResult response, Encoder encoder)
{
if (response == null) return; // not actually a request (i.e. empty/whitespace request)
if (client != null && client.ShouldSkipResponse()) return; // intentionally skipping the result
char prefix;
switch (response.Type)
{
case ResultType.Integer:
PhysicalConnection.WriteInteger(output, (long)response);
break;
case ResultType.Error:
prefix = '-';
goto BasicMessage;
case ResultType.SimpleString:
prefix = '+';
BasicMessage:
var span = output.GetSpan(1);
span[0] = (byte)prefix;
output.Advance(1);
var val = response.AsString();
var expectedLength = Encoding.UTF8.GetByteCount(val);
PhysicalConnection.WriteRaw(output, val, expectedLength, encoder);
PhysicalConnection.WriteCrlf(output);
break;
case ResultType.BulkString:
PhysicalConnection.WriteBulkString(response.AsRedisValue(), output, encoder);
break;
case ResultType.MultiBulk:
if (response.IsNull)
{
PhysicalConnection.WriteMultiBulkHeader(output, -1);
}
else
{
var arr = (RedisResult[])response;
PhysicalConnection.WriteMultiBulkHeader(output, arr.Length);
for (int i = 0; i < arr.Length; i++)
{
var item = arr[i];
if (item == null)
throw new InvalidOperationException("Array element cannot be null, index " + i);
WriteResponse(null, output, item, encoder); // note: don't pass client down; this would impact SkipReplies
}
}
break;
default:
throw new InvalidOperationException(
"Unexpected result type: " + response.Type);
}
}
public static bool TryParseRequest(ref ReadOnlySequence<byte> buffer, out RedisRequest request)
{
var reader = new BufferReader(buffer);
var raw = PhysicalConnection.TryParseResult(in buffer, ref reader, false, null, true);
if (raw.HasValue)
{
buffer = reader.SliceFromCurrent();
request = new RedisRequest(raw);
return true;
}
request = default;
return false;
}
bool TryProcessRequest(ref ReadOnlySequence<byte> buffer, RedisClient client, PipeWriter output)
{
if (!buffer.IsEmpty && TryParseRequest(ref buffer, out var request))
{
RedisResult response;
try { response = Execute(client, request); }
finally { request.Recycle(); }
WriteResponse(client, output, response, _serverEncoder);
return true;
}
return false;
}
private object ServerSyncLock => this;
private long _commandsProcesed;
public long CommandsProcesed => _commandsProcesed;
public RedisResult Execute(RedisClient client, RedisRequest request)
{
if (string.IsNullOrWhiteSpace(request.Command)) return null; // not a request
Interlocked.Increment(ref _commandsProcesed);
try
{
RedisResult result;
if(_commands.TryGetValue(request.Command, out var cmd))
{
request = request.AsCommand(cmd.Command); // fixup casing
if (cmd.HasSubCommands)
{
cmd = cmd.Resolve(request);
if (cmd.IsUnknown) return request.UnknownSubcommandOrArgumentCount();
}
if(cmd.LockFree)
{
result = cmd.Execute(client, request);
}
else
{
lock(ServerSyncLock)
{
result = cmd.Execute(client, request);
}
}
}
else
{
result = null;
}
if (result == null) Log($"missing command: '{request.Command}'");
return result ?? CommandNotFound(request.Command);
}
catch (NotSupportedException)
{
Log($"missing command: '{request.Command}'");
return CommandNotFound(request.Command);
}
catch (NotImplementedException)
{
Log($"missing command: '{request.Command}'");
return CommandNotFound(request.Command);
}
catch (InvalidCastException)
{
return RedisResult.Create("WRONGTYPE Operation against a key holding the wrong kind of value", ResultType.Error);
}
catch (Exception ex)
{
if(!_isShutdown) Log(ex.Message);
return RedisResult.Create("ERR " + ex.Message, ResultType.Error);
}
}
internal static string ToLower(RawResult value)
{
var val = value.GetString();
if (string.IsNullOrWhiteSpace(val)) return val;
return val.ToLowerInvariant();
}
protected static RedisResult CommandNotFound(string command)
=> RedisResult.Create($"ERR unknown command '{command}'", ResultType.Error);
[RedisCommand(1, LockFree = true)]
protected virtual RedisResult Command(RedisClient client, RedisRequest request)
{
var results = new RedisResult[_commands.Count];
int index = 0;
foreach (var pair in _commands)
results[index++] = CommandInfo(pair.Value);
return RedisResult.Create(results);
}
[RedisCommand(-2, "command", "info", LockFree = true)]
protected virtual RedisResult CommandInfo(RedisClient client, RedisRequest request)
{
var results = new RedisResult[request.Count - 2];
for (int i = 2; i < request.Count; i++)
{
results[i - 2] = _commands.TryGetValue(request.GetString(i), out var cmd)
? CommandInfo(cmd) : null;
}
return RedisResult.Create(results);
}
private RedisResult CommandInfo(RespCommand command)
=> RedisResult.Create(new[]
{
RedisResult.Create(command.Command, ResultType.BulkString),
RedisResult.Create(command.NetArity(), ResultType.Integer),
RedisResult.EmptyArray,
RedisResult.Zero,
RedisResult.Zero,
RedisResult.Zero,
});
}
}
<Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup>
<TargetFrameworks>$(LibraryTargetFrameworks)</TargetFrameworks>
<Description>Basic redis server based on StackExchange.Redis</Description>
<AssemblyTitle>StackExchange.Redis</AssemblyTitle>
<GenerateDocumentationFile>true</GenerateDocumentationFile>
<AssemblyName>StackExchange.Redis.Server</AssemblyName>
<PackageId>StackExchange.Redis.Server</PackageId>
<PackageTags>Server;Async;Redis;Cache;PubSub;Messaging</PackageTags>
<OutputTypeEx>Library</OutputTypeEx>
<SignAssembly>true</SignAssembly>
<PublicSign Condition=" '$(OS)' != 'Windows_NT' ">true</PublicSign>
<LangVersion>latest</LangVersion>
<NoWarn>$(NoWarn);CS1591</NoWarn>
</PropertyGroup>
<PropertyGroup Condition="'$(Configuration)|$(TargetFramework)|$(Platform)'=='Debug|netstandard2.0|AnyCPU'">
<AllowUnsafeBlocks>false</AllowUnsafeBlocks>
</PropertyGroup>
<ItemGroup>
<ProjectReference Include="..\StackExchange.Redis\StackExchange.Redis.csproj" />
<PackageReference Include="System.Runtime.Caching" Version="4.5.0" />
</ItemGroup>
</Project>
# Wait, what is this?
This is **not** a replacement for redis!
This is some example code that illustrates using "pipelines" to implement a server, in this case a server that works like 'redis',
implementing the same protocol, and offering similar services.
What it isn't:
- supported
- as good as redis
- feature complete
- bug free
What it is:
- useful for me to test my protocol handling
- useful for debugging
- useful for anyone looking for reference code for implementing a custom server based on pipelines
- fun
Example usage:
```c#
using System;
using System.Net;
using System.Threading.Tasks;
using StackExchange.Redis.Server;
static class Program
{
static async Task Main()
{
using (var server = new MemoryCacheRedisServer(Console.Out))
{
server.Listen(new IPEndPoint(IPAddress.Loopback, 6379));
await server.Shutdown;
}
}
}
```
\ No newline at end of file
...@@ -677,11 +677,11 @@ public void MovedProfiling() ...@@ -677,11 +677,11 @@ public void MovedProfiling()
var Key = Me(); var Key = Me();
const string Value = "redirected-value"; const string Value = "redirected-value";
var profiler = new ProfilingSession(); var profiler = new Profiling.PerThreadProfiler();
using (var conn = Create()) using (var conn = Create())
{ {
conn.RegisterProfiler(() => profiler); conn.RegisterProfiler(profiler.GetSession);
var endpoints = conn.GetEndPoints(); var endpoints = conn.GetEndPoints();
var servers = endpoints.Select(e => conn.GetServer(e)); var servers = endpoints.Select(e => conn.GetServer(e));
...@@ -705,7 +705,7 @@ public void MovedProfiling() ...@@ -705,7 +705,7 @@ public void MovedProfiling()
string b = (string)conn.GetServer(wrongMasterNode.EndPoint).Execute("GET", Key); string b = (string)conn.GetServer(wrongMasterNode.EndPoint).Execute("GET", Key);
Assert.Equal(Value, b); // wrong master, allow redirect Assert.Equal(Value, b); // wrong master, allow redirect
var msgs = profiler.FinishProfiling().ToList(); var msgs = profiler.GetSession().FinishProfiling().ToList();
// verify that things actually got recorded properly, and the retransmission profilings are connected as expected // verify that things actually got recorded properly, and the retransmission profilings are connected as expected
{ {
......
...@@ -202,13 +202,13 @@ public void ManyContexts() ...@@ -202,13 +202,13 @@ public void ManyContexts()
} }
} }
private class PerThreadProfiler internal class PerThreadProfiler
{ {
ThreadLocal<ProfilingSession> perThreadSession = new ThreadLocal<ProfilingSession>(() => new ProfilingSession()); ThreadLocal<ProfilingSession> perThreadSession = new ThreadLocal<ProfilingSession>(() => new ProfilingSession());
public ProfilingSession GetSession() => perThreadSession.Value; public ProfilingSession GetSession() => perThreadSession.Value;
} }
private class AsyncLocalProfiler internal class AsyncLocalProfiler
{ {
AsyncLocal<ProfilingSession> perThreadSession = new AsyncLocal<ProfilingSession>(); AsyncLocal<ProfilingSession> perThreadSession = new AsyncLocal<ProfilingSession>();
......
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
<GenerateDocumentationFile>false</GenerateDocumentationFile> <GenerateDocumentationFile>false</GenerateDocumentationFile>
<SignAssembly>true</SignAssembly> <SignAssembly>true</SignAssembly>
<DebugType>full</DebugType> <DebugType>full</DebugType>
<LangVersion>latest</LangVersion>
</PropertyGroup> </PropertyGroup>
<ItemGroup> <ItemGroup>
......
...@@ -79,6 +79,8 @@ Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Failover", "Failover", "{D0 ...@@ -79,6 +79,8 @@ Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Failover", "Failover", "{D0
RedisConfigs\Failover\slave-6383.conf = RedisConfigs\Failover\slave-6383.conf RedisConfigs\Failover\slave-6383.conf = RedisConfigs\Failover\slave-6383.conf
EndProjectSection EndProjectSection
EndProject EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "StackExchange.Redis.Server", "StackExchange.Redis.Server\StackExchange.Redis.Server.csproj", "{8375813E-FBAF-4DA3-A2C7-E4645B39B931}"
EndProject
Global Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution GlobalSection(SolutionConfigurationPlatforms) = preSolution
Debug|Any CPU = Debug|Any CPU Debug|Any CPU = Debug|Any CPU
...@@ -117,6 +119,10 @@ Global ...@@ -117,6 +119,10 @@ Global
{769640F3-889C-4E8A-A7DF-916AE9B432A6}.Debug|Any CPU.Build.0 = Debug|Any CPU {769640F3-889C-4E8A-A7DF-916AE9B432A6}.Debug|Any CPU.Build.0 = Debug|Any CPU
{769640F3-889C-4E8A-A7DF-916AE9B432A6}.Release|Any CPU.ActiveCfg = Release|Any CPU {769640F3-889C-4E8A-A7DF-916AE9B432A6}.Release|Any CPU.ActiveCfg = Release|Any CPU
{769640F3-889C-4E8A-A7DF-916AE9B432A6}.Release|Any CPU.Build.0 = Release|Any CPU {769640F3-889C-4E8A-A7DF-916AE9B432A6}.Release|Any CPU.Build.0 = Release|Any CPU
{8375813E-FBAF-4DA3-A2C7-E4645B39B931}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{8375813E-FBAF-4DA3-A2C7-E4645B39B931}.Debug|Any CPU.Build.0 = Debug|Any CPU
{8375813E-FBAF-4DA3-A2C7-E4645B39B931}.Release|Any CPU.ActiveCfg = Release|Any CPU
{8375813E-FBAF-4DA3-A2C7-E4645B39B931}.Release|Any CPU.Build.0 = Release|Any CPU
EndGlobalSection EndGlobalSection
GlobalSection(SolutionProperties) = preSolution GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE HideSolutionNode = FALSE
......
...@@ -2,4 +2,5 @@ ...@@ -2,4 +2,5 @@
// your version numbers. Therefore, we need to move the attribute out into another file...this file. // your version numbers. Therefore, we need to move the attribute out into another file...this file.
// When .csproj merges in, this should be able to return to Properties/AssemblyInfo.cs // When .csproj merges in, this should be able to return to Properties/AssemblyInfo.cs
using System.Runtime.CompilerServices; using System.Runtime.CompilerServices;
[assembly: InternalsVisibleTo("StackExchange.Redis.Server, PublicKey=00240000048000009400000006020000002400005253413100040000010001007791a689e9d8950b44a9a8886baad2ea180e7a8a854f158c9b98345ca5009cdd2362c84f368f1c3658c132b3c0f74e44ff16aeb2e5b353b6e0fe02f923a050470caeac2bde47a2238a9c7125ed7dab14f486a5a64558df96640933b9f2b6db188fc4a820f96dce963b662fa8864adbff38e5b4542343f162ecdc6dad16912fff")]
[assembly: InternalsVisibleTo("StackExchange.Redis.Tests, PublicKey=00240000048000009400000006020000002400005253413100040000010001007791a689e9d8950b44a9a8886baad2ea180e7a8a854f158c9b98345ca5009cdd2362c84f368f1c3658c132b3c0f74e44ff16aeb2e5b353b6e0fe02f923a050470caeac2bde47a2238a9c7125ed7dab14f486a5a64558df96640933b9f2b6db188fc4a820f96dce963b662fa8864adbff38e5b4542343f162ecdc6dad16912fff")] [assembly: InternalsVisibleTo("StackExchange.Redis.Tests, PublicKey=00240000048000009400000006020000002400005253413100040000010001007791a689e9d8950b44a9a8886baad2ea180e7a8a854f158c9b98345ca5009cdd2362c84f368f1c3658c132b3c0f74e44ff16aeb2e5b353b6e0fe02f923a050470caeac2bde47a2238a9c7125ed7dab14f486a5a64558df96640933b9f2b6db188fc4a820f96dce963b662fa8864adbff38e5b4542343f162ecdc6dad16912fff")]
using System;
using System.Buffers;
using System.IO;
namespace StackExchange.Redis
{
internal enum ConsumeResult
{
Failure,
Success,
NeedMoreData,
}
internal ref struct BufferReader
{
private ReadOnlySequence<byte>.Enumerator _iterator;
private ReadOnlySpan<byte> _current;
public ReadOnlySpan<byte> OversizedSpan => _current;
public ReadOnlySpan<byte> SlicedSpan => _current.Slice(OffsetThisSpan, RemainingThisSpan);
public int OffsetThisSpan { get; private set; }
private int TotalConsumed { get; set; } // hide this; callers should use the snapshot-aware methods instead
public int RemainingThisSpan { get; private set; }
public bool IsEmpty => RemainingThisSpan == 0;
private bool FetchNextSegment()
{
do
{
if (!_iterator.MoveNext())
{
OffsetThisSpan = RemainingThisSpan = 0;
return false;
}
_current = _iterator.Current.Span;
OffsetThisSpan = 0;
RemainingThisSpan = _current.Length;
} while (IsEmpty); // skip empty segments, they don't help us!
return true;
}
public BufferReader(ReadOnlySequence<byte> buffer)
{
_buffer = buffer;
_lastSnapshotPosition = buffer.Start;
_lastSnapshotBytes = 0;
_iterator = buffer.GetEnumerator();
_current = default;
OffsetThisSpan = RemainingThisSpan = TotalConsumed = 0;
FetchNextSegment();
}
private static readonly byte[] CRLF = { (byte)'\r', (byte)'\n' };
/// <summary>
/// Note that in results other than success, no guarantees are made about final state; if you care: snapshot
/// </summary>
public ConsumeResult TryConsumeCRLF()
{
switch (RemainingThisSpan)
{
case 0:
return ConsumeResult.NeedMoreData;
case 1:
if (_current[OffsetThisSpan] != (byte)'\r') return ConsumeResult.Failure;
Consume(1);
if (IsEmpty) return ConsumeResult.NeedMoreData;
var next = _current[OffsetThisSpan];
Consume(1);
return next == '\n' ? ConsumeResult.Success : ConsumeResult.Failure;
default:
var offset = OffsetThisSpan;
var result = _current[offset++] == (byte)'\r' && _current[offset] == (byte)'\n'
? ConsumeResult.Success : ConsumeResult.Failure;
Consume(2);
return result;
}
}
public bool TryConsume(int count)
{
if (count < 0) throw new ArgumentOutOfRangeException(nameof(count));
do
{
var available = RemainingThisSpan;
if (count <= available)
{
// consume part of this span
TotalConsumed += count;
RemainingThisSpan -= count;
OffsetThisSpan += count;
if (count == available) FetchNextSegment(); // burned all of it; fetch next
return true;
}
// consume all of this span
TotalConsumed += available;
count -= available;
} while (FetchNextSegment());
return false;
}
private readonly ReadOnlySequence<byte> _buffer;
private SequencePosition _lastSnapshotPosition;
private long _lastSnapshotBytes;
// makes an internal note of where we are, as a SequencePosition; useful
// to avoid having to use buffer.Slice on huge ranges
private SequencePosition SnapshotPosition()
{
var consumed = TotalConsumed;
var delta = consumed - _lastSnapshotBytes;
if (delta == 0) return _lastSnapshotPosition;
var pos = _buffer.GetPosition(delta, _lastSnapshotPosition);
_lastSnapshotBytes = consumed;
return _lastSnapshotPosition = pos;
}
public ReadOnlySequence<byte> ConsumeAsBuffer(int count)
{
if (!TryConsumeAsBuffer(count, out var buffer)) throw new EndOfStreamException();
return buffer;
}
public ReadOnlySequence<byte> ConsumeToEnd()
{
var from = SnapshotPosition();
var result = _buffer.Slice(from);
while (FetchNextSegment()) { } // consume all
return result;
}
public bool TryConsumeAsBuffer(int count, out ReadOnlySequence<byte> buffer)
{
var from = SnapshotPosition();
if (!TryConsume(count))
{
buffer = default;
return false;
}
var to = SnapshotPosition();
buffer = _buffer.Slice(from, to);
return true;
}
public void Consume(int count)
{
if (!TryConsume(count)) throw new EndOfStreamException();
}
internal static int FindNext(BufferReader reader, byte value) // very deliberately not ref; want snapshot
{
int totalSkipped = 0;
do
{
if (reader.RemainingThisSpan == 0) continue;
var span = reader.SlicedSpan;
int found = span.IndexOf(value);
if (found >= 0) return totalSkipped + found;
totalSkipped += span.Length;
} while (reader.FetchNextSegment());
return -1;
}
internal static int FindNextCrLf(BufferReader reader) // very deliberately not ref; want snapshot
{
// is it in the current span? (we need to handle the offsets differently if so)
int totalSkipped = 0;
bool haveTrailingCR = false;
do
{
if (reader.RemainingThisSpan == 0) continue;
var span = reader.SlicedSpan;
if (haveTrailingCR)
{
if (span[0] == '\n') return totalSkipped - 1;
haveTrailingCR = false;
}
int found = span.IndexOf(CRLF);
if (found >= 0) return totalSkipped + found;
haveTrailingCR = span[span.Length - 1] == '\r';
totalSkipped += span.Length;
}
while (reader.FetchNextSegment());
return -1;
}
//internal static bool HasBytes(BufferReader reader, int count) // very deliberately not ref; want snapshot
//{
// if (count < 0) throw new ArgumentOutOfRangeException(nameof(count));
// do
// {
// var available = reader.RemainingThisSpan;
// if (count <= available) return true;
// count -= available;
// } while (reader.FetchNextSegment());
// return false;
//}
public int ConsumeByte()
{
if (IsEmpty) return -1;
var value = _current[OffsetThisSpan];
Consume(1);
return value;
}
public int PeekByte() => IsEmpty ? -1 : _current[OffsetThisSpan];
public ReadOnlySequence<byte> SliceFromCurrent()
{
var from = SnapshotPosition();
return _buffer.Slice(from);
}
}
}
...@@ -308,7 +308,7 @@ protected override void WriteImpl(PhysicalConnection physical) ...@@ -308,7 +308,7 @@ protected override void WriteImpl(PhysicalConnection physical)
{ {
physical.WriteHeader(command, 2); physical.WriteHeader(command, 2);
physical.Write(Key); physical.Write(Key);
physical.Write(value); physical.WriteBulkString(value);
} }
} }
} }
......
...@@ -756,7 +756,7 @@ protected override void WriteImpl(PhysicalConnection physical) ...@@ -756,7 +756,7 @@ protected override void WriteImpl(PhysicalConnection physical)
{ {
physical.WriteHeader(Command, 2); physical.WriteHeader(Command, 2);
physical.Write(Channel); physical.Write(Channel);
physical.Write(value); physical.WriteBulkString(value);
} }
} }
...@@ -857,7 +857,7 @@ protected override void WriteImpl(PhysicalConnection physical) ...@@ -857,7 +857,7 @@ protected override void WriteImpl(PhysicalConnection physical)
physical.WriteHeader(Command, 3); physical.WriteHeader(Command, 3);
physical.Write(Key); physical.Write(Key);
physical.Write(key1); physical.Write(key1);
physical.Write(value); physical.WriteBulkString(value);
} }
} }
...@@ -889,7 +889,7 @@ protected override void WriteImpl(PhysicalConnection physical) ...@@ -889,7 +889,7 @@ protected override void WriteImpl(PhysicalConnection physical)
physical.WriteHeader(command, values.Length); physical.WriteHeader(command, values.Length);
for (int i = 0; i < values.Length; i++) for (int i = 0; i < values.Length; i++)
{ {
physical.Write(values[i]); physical.WriteBulkString(values[i]);
} }
} }
} }
...@@ -939,7 +939,7 @@ protected override void WriteImpl(PhysicalConnection physical) ...@@ -939,7 +939,7 @@ protected override void WriteImpl(PhysicalConnection physical)
{ {
physical.WriteHeader(Command, 2); physical.WriteHeader(Command, 2);
physical.Write(Key); physical.Write(Key);
physical.Write(value); physical.WriteBulkString(value);
} }
} }
...@@ -968,7 +968,7 @@ protected override void WriteImpl(PhysicalConnection physical) ...@@ -968,7 +968,7 @@ protected override void WriteImpl(PhysicalConnection physical)
{ {
physical.WriteHeader(Command, values.Length + 2); physical.WriteHeader(Command, values.Length + 2);
physical.Write(Key); physical.Write(Key);
for (int i = 0; i < values.Length; i++) physical.Write(values[i]); for (int i = 0; i < values.Length; i++) physical.WriteBulkString(values[i]);
physical.Write(key1); physical.Write(key1);
} }
} }
...@@ -989,7 +989,7 @@ protected override void WriteImpl(PhysicalConnection physical) ...@@ -989,7 +989,7 @@ protected override void WriteImpl(PhysicalConnection physical)
{ {
physical.WriteHeader(Command, values.Length + 1); physical.WriteHeader(Command, values.Length + 1);
physical.Write(Key); physical.Write(Key);
for (int i = 0; i < values.Length; i++) physical.Write(values[i]); for (int i = 0; i < values.Length; i++) physical.WriteBulkString(values[i]);
} }
} }
...@@ -1008,8 +1008,8 @@ protected override void WriteImpl(PhysicalConnection physical) ...@@ -1008,8 +1008,8 @@ protected override void WriteImpl(PhysicalConnection physical)
{ {
physical.WriteHeader(Command, 3); physical.WriteHeader(Command, 3);
physical.Write(Key); physical.Write(Key);
physical.Write(value0); physical.WriteBulkString(value0);
physical.Write(value1); physical.WriteBulkString(value1);
} }
} }
...@@ -1030,9 +1030,9 @@ protected override void WriteImpl(PhysicalConnection physical) ...@@ -1030,9 +1030,9 @@ protected override void WriteImpl(PhysicalConnection physical)
{ {
physical.WriteHeader(Command, 4); physical.WriteHeader(Command, 4);
physical.Write(Key); physical.Write(Key);
physical.Write(value0); physical.WriteBulkString(value0);
physical.Write(value1); physical.WriteBulkString(value1);
physical.Write(value2); physical.WriteBulkString(value2);
} }
} }
...@@ -1055,10 +1055,10 @@ protected override void WriteImpl(PhysicalConnection physical) ...@@ -1055,10 +1055,10 @@ protected override void WriteImpl(PhysicalConnection physical)
{ {
physical.WriteHeader(Command, 5); physical.WriteHeader(Command, 5);
physical.Write(Key); physical.Write(Key);
physical.Write(value0); physical.WriteBulkString(value0);
physical.Write(value1); physical.WriteBulkString(value1);
physical.Write(value2); physical.WriteBulkString(value2);
physical.Write(value3); physical.WriteBulkString(value3);
} }
} }
...@@ -1097,7 +1097,7 @@ protected override void WriteImpl(PhysicalConnection physical) ...@@ -1097,7 +1097,7 @@ protected override void WriteImpl(PhysicalConnection physical)
physical.WriteHeader(command, values.Length); physical.WriteHeader(command, values.Length);
for (int i = 0; i < values.Length; i++) for (int i = 0; i < values.Length; i++)
{ {
physical.Write(values[i]); physical.WriteBulkString(values[i]);
} }
} }
} }
...@@ -1114,7 +1114,7 @@ public CommandValueChannelMessage(int db, CommandFlags flags, RedisCommand comma ...@@ -1114,7 +1114,7 @@ public CommandValueChannelMessage(int db, CommandFlags flags, RedisCommand comma
protected override void WriteImpl(PhysicalConnection physical) protected override void WriteImpl(PhysicalConnection physical)
{ {
physical.WriteHeader(Command, 2); physical.WriteHeader(Command, 2);
physical.Write(value); physical.WriteBulkString(value);
physical.Write(Channel); physical.Write(Channel);
} }
} }
...@@ -1138,7 +1138,7 @@ public override void AppendStormLog(StringBuilder sb) ...@@ -1138,7 +1138,7 @@ public override void AppendStormLog(StringBuilder sb)
protected override void WriteImpl(PhysicalConnection physical) protected override void WriteImpl(PhysicalConnection physical)
{ {
physical.WriteHeader(Command, 2); physical.WriteHeader(Command, 2);
physical.Write(value); physical.WriteBulkString(value);
physical.Write(Key); physical.Write(Key);
} }
} }
...@@ -1155,7 +1155,7 @@ public CommandValueMessage(int db, CommandFlags flags, RedisCommand command, Red ...@@ -1155,7 +1155,7 @@ public CommandValueMessage(int db, CommandFlags flags, RedisCommand command, Red
protected override void WriteImpl(PhysicalConnection physical) protected override void WriteImpl(PhysicalConnection physical)
{ {
physical.WriteHeader(Command, 1); physical.WriteHeader(Command, 1);
physical.Write(value); physical.WriteBulkString(value);
} }
} }
...@@ -1173,8 +1173,8 @@ public CommandValueValueMessage(int db, CommandFlags flags, RedisCommand command ...@@ -1173,8 +1173,8 @@ public CommandValueValueMessage(int db, CommandFlags flags, RedisCommand command
protected override void WriteImpl(PhysicalConnection physical) protected override void WriteImpl(PhysicalConnection physical)
{ {
physical.WriteHeader(Command, 2); physical.WriteHeader(Command, 2);
physical.Write(value0); physical.WriteBulkString(value0);
physical.Write(value1); physical.WriteBulkString(value1);
} }
} }
...@@ -1194,9 +1194,9 @@ public CommandValueValueValueMessage(int db, CommandFlags flags, RedisCommand co ...@@ -1194,9 +1194,9 @@ public CommandValueValueValueMessage(int db, CommandFlags flags, RedisCommand co
protected override void WriteImpl(PhysicalConnection physical) protected override void WriteImpl(PhysicalConnection physical)
{ {
physical.WriteHeader(Command, 3); physical.WriteHeader(Command, 3);
physical.Write(value0); physical.WriteBulkString(value0);
physical.Write(value1); physical.WriteBulkString(value1);
physical.Write(value2); physical.WriteBulkString(value2);
} }
} }
...@@ -1220,11 +1220,11 @@ public CommandValueValueValueValueValueMessage(int db, CommandFlags flags, Redis ...@@ -1220,11 +1220,11 @@ public CommandValueValueValueValueValueMessage(int db, CommandFlags flags, Redis
protected override void WriteImpl(PhysicalConnection physical) protected override void WriteImpl(PhysicalConnection physical)
{ {
physical.WriteHeader(Command, 5); physical.WriteHeader(Command, 5);
physical.Write(value0); physical.WriteBulkString(value0);
physical.Write(value1); physical.WriteBulkString(value1);
physical.Write(value2); physical.WriteBulkString(value2);
physical.Write(value3); physical.WriteBulkString(value3);
physical.Write(value4); physical.WriteBulkString(value4);
} }
} }
...@@ -1237,7 +1237,7 @@ public SelectMessage(int db, CommandFlags flags) : base(db, flags, RedisCommand. ...@@ -1237,7 +1237,7 @@ public SelectMessage(int db, CommandFlags flags) : base(db, flags, RedisCommand.
protected override void WriteImpl(PhysicalConnection physical) protected override void WriteImpl(PhysicalConnection physical)
{ {
physical.WriteHeader(Command, 1); physical.WriteHeader(Command, 1);
physical.Write(Db); physical.WriteBulkString(Db);
} }
} }
} }
......
...@@ -581,7 +581,7 @@ internal void Write(RedisKey key) ...@@ -581,7 +581,7 @@ internal void Write(RedisKey key)
var val = key.KeyValue; var val = key.KeyValue;
if (val is string) if (val is string)
{ {
WriteUnified(_ioPipe.Output, key.KeyPrefix, (string)val); WriteUnified(_ioPipe.Output, key.KeyPrefix, (string)val, outEncoder);
} }
else else
{ {
...@@ -590,26 +590,27 @@ internal void Write(RedisKey key) ...@@ -590,26 +590,27 @@ internal void Write(RedisKey key)
} }
internal void Write(RedisChannel channel) internal void Write(RedisChannel channel)
{ => WriteUnified(_ioPipe.Output, ChannelPrefix, channel.Value);
WriteUnified(_ioPipe.Output, ChannelPrefix, channel.Value);
}
internal void Write(RedisValue value) [MethodImpl(MethodImplOptions.AggressiveInlining)]
internal void WriteBulkString(RedisValue value)
=> WriteBulkString(value, _ioPipe.Output, outEncoder);
internal static void WriteBulkString(RedisValue value, PipeWriter output, Encoder outEncoder)
{ {
switch (value.Type) switch (value.Type)
{ {
case RedisValue.StorageType.Null: case RedisValue.StorageType.Null:
WriteUnified(_ioPipe.Output, (byte[])null); WriteUnified(output, (byte[])null);
break; break;
case RedisValue.StorageType.Int64: case RedisValue.StorageType.Int64:
WriteUnified(_ioPipe.Output, (long)value); WriteUnified(output, (long)value);
break; break;
case RedisValue.StorageType.Double: // use string case RedisValue.StorageType.Double: // use string
case RedisValue.StorageType.String: case RedisValue.StorageType.String:
WriteUnified(_ioPipe.Output, null, (string)value); WriteUnified(output, null, (string)value, outEncoder);
break; break;
case RedisValue.StorageType.Raw: case RedisValue.StorageType.Raw:
WriteUnified(_ioPipe.Output, ((ReadOnlyMemory<byte>)value).Span); WriteUnified(output, ((ReadOnlyMemory<byte>)value).Span);
break; break;
default: default:
throw new InvalidOperationException($"Unexpected {value.Type} value: '{value}'"); throw new InvalidOperationException($"Unexpected {value.Type} value: '{value}'");
...@@ -660,13 +661,21 @@ private void WriteHeader(byte[] commandBytes, int arguments) ...@@ -660,13 +661,21 @@ private void WriteHeader(byte[] commandBytes, int arguments)
_ioPipe.Output.Advance(offset); _ioPipe.Output.Advance(offset);
} }
internal static void WriteMultiBulkHeader(PipeWriter output, long count)
{
// *{count}\r\n = 3 + MaxInt32TextLen
var span = output.GetSpan(3 + MaxInt32TextLen);
span[0] = (byte)'*';
int offset = WriteRaw(span, count, offset: 1);
output.Advance(offset);
}
internal const int internal const int
MaxInt32TextLen = 11, // -2,147,483,648 (not including the commas) MaxInt32TextLen = 11, // -2,147,483,648 (not including the commas)
MaxInt64TextLen = 20; // -9,223,372,036,854,775,808 (not including the commas) MaxInt64TextLen = 20; // -9,223,372,036,854,775,808 (not including the commas)
[MethodImpl(MethodImplOptions.AggressiveInlining)] [MethodImpl(MethodImplOptions.AggressiveInlining)]
private static int WriteCrlf(Span<byte> span, int offset) internal static int WriteCrlf(Span<byte> span, int offset)
{ {
span[offset++] = (byte)'\r'; span[offset++] = (byte)'\r';
span[offset++] = (byte)'\n'; span[offset++] = (byte)'\n';
...@@ -674,7 +683,7 @@ private static int WriteCrlf(Span<byte> span, int offset) ...@@ -674,7 +683,7 @@ private static int WriteCrlf(Span<byte> span, int offset)
} }
[MethodImpl(MethodImplOptions.AggressiveInlining)] [MethodImpl(MethodImplOptions.AggressiveInlining)]
private static void WriteCrlf(PipeWriter writer) internal static void WriteCrlf(PipeWriter writer)
{ {
var span = writer.GetSpan(2); var span = writer.GetSpan(2);
span[0] = (byte)'\r'; span[0] = (byte)'\r';
...@@ -902,7 +911,7 @@ internal static byte ToHexNibble(int value) ...@@ -902,7 +911,7 @@ internal static byte ToHexNibble(int value)
return value < 10 ? (byte)('0' + value) : (byte)('a' - 10 + value); return value < 10 ? (byte)('0' + value) : (byte)('a' - 10 + value);
} }
private void WriteUnified(PipeWriter writer, byte[] prefix, string value) internal static void WriteUnified(PipeWriter writer, byte[] prefix, string value, Encoder outEncoder)
{ {
if (value == null) if (value == null)
{ {
...@@ -930,13 +939,13 @@ private void WriteUnified(PipeWriter writer, byte[] prefix, string value) ...@@ -930,13 +939,13 @@ private void WriteUnified(PipeWriter writer, byte[] prefix, string value)
writer.Advance(bytes); writer.Advance(bytes);
if (prefixLength != 0) writer.Write(prefix); if (prefixLength != 0) writer.Write(prefix);
if (encodedLength != 0) WriteRaw(writer, value, encodedLength); if (encodedLength != 0) WriteRaw(writer, value, encodedLength, outEncoder);
WriteCrlf(writer); WriteCrlf(writer);
} }
} }
} }
private unsafe void WriteRaw(PipeWriter writer, string value, int expectedLength) unsafe static internal void WriteRaw(PipeWriter writer, string value, int expectedLength, Encoder outEncoder)
{ {
const int MaxQuickEncodeSize = 512; const int MaxQuickEncodeSize = 512;
...@@ -1029,6 +1038,17 @@ private static void WriteUnified(PipeWriter writer, long value) ...@@ -1029,6 +1038,17 @@ private static void WriteUnified(PipeWriter writer, long value)
var bytes = WriteRaw(span, value, withLengthPrefix: true, offset: 1); var bytes = WriteRaw(span, value, withLengthPrefix: true, offset: 1);
writer.Advance(bytes); writer.Advance(bytes);
} }
internal static void WriteInteger(PipeWriter writer, long value)
{
//note: client should never write integer; only server does this
// :{asc}\r\n = MaxInt64TextLen + 3
var span = writer.GetSpan(3 + MaxInt64TextLen);
span[0] = (byte)':';
var bytes = WriteRaw(span, value, withLengthPrefix: false, offset: 1);
writer.Advance(bytes);
}
internal int GetAvailableInboundBytes() => _socket?.Available ?? -1; internal int GetAvailableInboundBytes() => _socket?.Available ?? -1;
...@@ -1288,7 +1308,7 @@ private int ProcessBuffer(ref ReadOnlySequence<byte> buffer) ...@@ -1288,7 +1308,7 @@ private int ProcessBuffer(ref ReadOnlySequence<byte> buffer)
while (!buffer.IsEmpty) while (!buffer.IsEmpty)
{ {
var reader = new BufferReader(buffer); var reader = new BufferReader(buffer);
var result = TryParseResult(in buffer, ref reader); var result = TryParseResult(in buffer, ref reader, IncludeDetailInExceptions, BridgeCouldBeNull?.ServerEndPoint);
try try
{ {
if (result.HasValue) if (result.HasValue)
...@@ -1336,12 +1356,12 @@ private int ProcessBuffer(ref ReadOnlySequence<byte> buffer) ...@@ -1336,12 +1356,12 @@ private int ProcessBuffer(ref ReadOnlySequence<byte> buffer)
// } // }
//} //}
private RawResult ReadArray(in ReadOnlySequence<byte> buffer, ref BufferReader reader) private static RawResult ReadArray(in ReadOnlySequence<byte> buffer, ref BufferReader reader, bool includeDetailInExceptions, ServerEndPoint server)
{ {
var itemCount = ReadLineTerminatedString(ResultType.Integer, in buffer, ref reader); var itemCount = ReadLineTerminatedString(ResultType.Integer, in buffer, ref reader);
if (itemCount.HasValue) if (itemCount.HasValue)
{ {
if (!itemCount.TryGetInt64(out long i64)) throw ExceptionFactory.ConnectionFailure(IncludeDetailInExceptions, ConnectionFailureType.ProtocolFailure, "Invalid array length", BridgeCouldBeNull?.ServerEndPoint); if (!itemCount.TryGetInt64(out long i64)) throw ExceptionFactory.ConnectionFailure(includeDetailInExceptions, ConnectionFailureType.ProtocolFailure, "Invalid array length", server);
int itemCountActual = checked((int)i64); int itemCountActual = checked((int)i64);
if (itemCountActual < 0) if (itemCountActual < 0)
...@@ -1359,7 +1379,7 @@ private RawResult ReadArray(in ReadOnlySequence<byte> buffer, ref BufferReader r ...@@ -1359,7 +1379,7 @@ private RawResult ReadArray(in ReadOnlySequence<byte> buffer, ref BufferReader r
var result = new RawResult(oversized, itemCountActual); var result = new RawResult(oversized, itemCountActual);
for (int i = 0; i < itemCountActual; i++) for (int i = 0; i < itemCountActual; i++)
{ {
if (!(oversized[i] = TryParseResult(in buffer, ref reader)).HasValue) if (!(oversized[i] = TryParseResult(in buffer, ref reader, includeDetailInExceptions, server)).HasValue)
{ {
result.Recycle(i); // passing index here means we don't need to "Array.Clear" before-hand result.Recycle(i); // passing index here means we don't need to "Array.Clear" before-hand
return RawResult.Nil; return RawResult.Nil;
...@@ -1370,12 +1390,12 @@ private RawResult ReadArray(in ReadOnlySequence<byte> buffer, ref BufferReader r ...@@ -1370,12 +1390,12 @@ private RawResult ReadArray(in ReadOnlySequence<byte> buffer, ref BufferReader r
return RawResult.Nil; return RawResult.Nil;
} }
private RawResult ReadBulkString(in ReadOnlySequence<byte> buffer, ref BufferReader reader) private static RawResult ReadBulkString(in ReadOnlySequence<byte> buffer, ref BufferReader reader, bool includeDetailInExceptions, ServerEndPoint server)
{ {
var prefix = ReadLineTerminatedString(ResultType.Integer, in buffer, ref reader); var prefix = ReadLineTerminatedString(ResultType.Integer, in buffer, ref reader);
if (prefix.HasValue) if (prefix.HasValue)
{ {
if (!prefix.TryGetInt64(out long i64)) throw ExceptionFactory.ConnectionFailure(IncludeDetailInExceptions, ConnectionFailureType.ProtocolFailure, "Invalid bulk string length", BridgeCouldBeNull?.ServerEndPoint); if (!prefix.TryGetInt64(out long i64)) throw ExceptionFactory.ConnectionFailure(includeDetailInExceptions, ConnectionFailureType.ProtocolFailure, "Invalid bulk string length", server);
int bodySize = checked((int)i64); int bodySize = checked((int)i64);
if (bodySize < 0) if (bodySize < 0)
{ {
...@@ -1391,14 +1411,14 @@ private RawResult ReadBulkString(in ReadOnlySequence<byte> buffer, ref BufferRea ...@@ -1391,14 +1411,14 @@ private RawResult ReadBulkString(in ReadOnlySequence<byte> buffer, ref BufferRea
case ConsumeResult.Success: case ConsumeResult.Success:
return new RawResult(ResultType.BulkString, payload, false); return new RawResult(ResultType.BulkString, payload, false);
default: default:
throw ExceptionFactory.ConnectionFailure(IncludeDetailInExceptions, ConnectionFailureType.ProtocolFailure, "Invalid bulk string terminator", BridgeCouldBeNull?.ServerEndPoint); throw ExceptionFactory.ConnectionFailure(includeDetailInExceptions, ConnectionFailureType.ProtocolFailure, "Invalid bulk string terminator", server);
} }
} }
} }
return RawResult.Nil; return RawResult.Nil;
} }
private RawResult ReadLineTerminatedString(ResultType type, in ReadOnlySequence<byte> buffer, ref BufferReader reader) private static RawResult ReadLineTerminatedString(ResultType type, in ReadOnlySequence<byte> buffer, ref BufferReader reader)
{ {
int crlfOffsetFromCurrent = BufferReader.FindNextCrLf(reader); int crlfOffsetFromCurrent = BufferReader.FindNextCrLf(reader);
if (crlfOffsetFromCurrent < 0) return RawResult.Nil; if (crlfOffsetFromCurrent < 0) return RawResult.Nil;
...@@ -1411,219 +1431,48 @@ private RawResult ReadLineTerminatedString(ResultType type, in ReadOnlySequence< ...@@ -1411,219 +1431,48 @@ private RawResult ReadLineTerminatedString(ResultType type, in ReadOnlySequence<
internal void StartReading() => ReadFromPipe(); internal void StartReading() => ReadFromPipe();
private RawResult TryParseResult(in ReadOnlySequence<byte> buffer, ref BufferReader reader) internal static RawResult TryParseResult(in ReadOnlySequence<byte> buffer, ref BufferReader reader,
bool includeDetilInExceptions, ServerEndPoint server, bool allowInlineProtocol = false)
{ {
var prefix = reader.ConsumeByte(); var prefix = reader.PeekByte();
if (prefix < 0) return RawResult.Nil; // EOF if (prefix < 0) return RawResult.Nil; // EOF
switch (prefix) switch (prefix)
{ {
case '+': // simple string case '+': // simple string
reader.Consume(1);
return ReadLineTerminatedString(ResultType.SimpleString, in buffer, ref reader); return ReadLineTerminatedString(ResultType.SimpleString, in buffer, ref reader);
case '-': // error case '-': // error
reader.Consume(1);
return ReadLineTerminatedString(ResultType.Error, in buffer, ref reader); return ReadLineTerminatedString(ResultType.Error, in buffer, ref reader);
case ':': // integer case ':': // integer
reader.Consume(1);
return ReadLineTerminatedString(ResultType.Integer, in buffer, ref reader); return ReadLineTerminatedString(ResultType.Integer, in buffer, ref reader);
case '$': // bulk string case '$': // bulk string
return ReadBulkString(in buffer, ref reader); reader.Consume(1);
return ReadBulkString(in buffer, ref reader, includeDetilInExceptions, server);
case '*': // array case '*': // array
return ReadArray(in buffer, ref reader); reader.Consume(1);
return ReadArray(in buffer, ref reader, includeDetilInExceptions, server);
default: default:
if (allowInlineProtocol) return ParseInlineProtocol(ReadLineTerminatedString(ResultType.SimpleString, in buffer, ref reader));
throw new InvalidOperationException("Unexpected response prefix: " + (char)prefix); throw new InvalidOperationException("Unexpected response prefix: " + (char)prefix);
} }
} }
static RawResult ParseInlineProtocol(RawResult line)
public enum ConsumeResult
{
Failure,
Success,
NeedMoreData,
}
private ref struct BufferReader
{ {
private ReadOnlySequence<byte>.Enumerator _iterator; if (!line.HasValue) return RawResult.Nil; // incomplete line
private ReadOnlySpan<byte> _current;
public ReadOnlySpan<byte> OversizedSpan => _current;
public ReadOnlySpan<byte> SlicedSpan => _current.Slice(OffsetThisSpan, RemainingThisSpan);
public int OffsetThisSpan { get; private set; }
private int TotalConsumed { get; set; } // hide this; callers should use the snapshot-aware methods instead
public int RemainingThisSpan { get; private set; }
public bool IsEmpty => RemainingThisSpan == 0;
private bool FetchNextSegment()
{
do
{
if (!_iterator.MoveNext())
{
OffsetThisSpan = RemainingThisSpan = 0;
return false;
}
_current = _iterator.Current.Span;
OffsetThisSpan = 0;
RemainingThisSpan = _current.Length;
} while (IsEmpty); // skip empty segments, they don't help us!
return true;
}
public BufferReader(ReadOnlySequence<byte> buffer)
{
_buffer = buffer;
_lastSnapshotPosition = buffer.Start;
_lastSnapshotBytes = 0;
_iterator = buffer.GetEnumerator();
_current = default;
OffsetThisSpan = RemainingThisSpan = TotalConsumed = 0;
FetchNextSegment();
}
private static readonly byte[] CRLF = { (byte)'\r', (byte)'\n' };
/// <summary>
/// Note that in results other than success, no guarantees are made about final state; if you care: snapshot
/// </summary>
public ConsumeResult TryConsumeCRLF()
{
switch (RemainingThisSpan)
{
case 0:
return ConsumeResult.NeedMoreData;
case 1:
if (_current[OffsetThisSpan] != (byte)'\r') return ConsumeResult.Failure;
Consume(1);
if (IsEmpty) return ConsumeResult.NeedMoreData;
var next = _current[OffsetThisSpan];
Consume(1);
return next == '\n' ? ConsumeResult.Success : ConsumeResult.Failure;
default:
var offset = OffsetThisSpan;
var result = _current[offset++] == (byte)'\r' && _current[offset] == (byte)'\n'
? ConsumeResult.Success : ConsumeResult.Failure;
Consume(2);
return result;
}
}
public bool TryConsume(int count)
{
if (count < 0) throw new ArgumentOutOfRangeException(nameof(count));
do
{
var available = RemainingThisSpan;
if (count <= available)
{
// consume part of this span
TotalConsumed += count;
RemainingThisSpan -= count;
OffsetThisSpan += count;
if (count == available) FetchNextSegment(); // burned all of it; fetch next
return true;
}
// consume all of this span
TotalConsumed += available;
count -= available;
} while (FetchNextSegment());
return false;
}
private readonly ReadOnlySequence<byte> _buffer;
private SequencePosition _lastSnapshotPosition;
private long _lastSnapshotBytes;
// makes an internal note of where we are, as a SequencePosition; useful
// to avoid having to use buffer.Slice on huge ranges
private SequencePosition SnapshotPosition()
{
var consumed = TotalConsumed;
var delta = consumed - _lastSnapshotBytes;
if (delta == 0) return _lastSnapshotPosition;
var pos = _buffer.GetPosition(delta, _lastSnapshotPosition);
_lastSnapshotBytes = consumed;
return _lastSnapshotPosition = pos;
}
public ReadOnlySequence<byte> ConsumeAsBuffer(int count)
{
if (!TryConsumeAsBuffer(count, out var buffer)) throw new EndOfStreamException();
return buffer;
}
public bool TryConsumeAsBuffer(int count, out ReadOnlySequence<byte> buffer)
{
var from = SnapshotPosition();
if (!TryConsume(count))
{
buffer = default;
return false;
}
var to = SnapshotPosition();
buffer = _buffer.Slice(from, to);
return true;
}
public void Consume(int count)
{
if (!TryConsume(count)) throw new EndOfStreamException();
}
internal static int FindNextCrLf(BufferReader reader) // very deliberately not ref; want snapshot int count = 0;
foreach (var token in line.GetInlineTokenizer()) count++;
var oversized = ArrayPool<RawResult>.Shared.Rent(count);
count = 0;
foreach (var token in line.GetInlineTokenizer())
{ {
// is it in the current span? (we need to handle the offsets differently if so) oversized[count++] = new RawResult(line.Type, token, false);
int totalSkipped = 0;
bool haveTrailingCR = false;
do
{
if (reader.RemainingThisSpan == 0) continue;
var span = reader.SlicedSpan;
if (haveTrailingCR)
{
if (span[0] == '\n') return totalSkipped - 1;
haveTrailingCR = false;
}
int found = span.IndexOf(CRLF);
if (found >= 0) return totalSkipped + found;
haveTrailingCR = span[span.Length - 1] == '\r';
totalSkipped += span.Length;
}
while (reader.FetchNextSegment());
return -1;
} }
return new RawResult(oversized, count);
}
//internal static bool HasBytes(BufferReader reader, int count) // very deliberately not ref; want snapshot
//{
// if (count < 0) throw new ArgumentOutOfRangeException(nameof(count));
// do
// {
// var available = reader.RemainingThisSpan;
// if (count <= available) return true;
// count -= available;
// } while (reader.FetchNextSegment());
// return false;
//}
public int ConsumeByte()
{
if (IsEmpty) return -1;
var value = _current[OffsetThisSpan];
Consume(1);
return value;
}
public ReadOnlySequence<byte> SliceFromCurrent()
{
var from = SnapshotPosition();
return _buffer.Slice(from);
}
}
} }
} }
...@@ -7,6 +7,15 @@ namespace StackExchange.Redis ...@@ -7,6 +7,15 @@ namespace StackExchange.Redis
{ {
internal readonly struct RawResult internal readonly struct RawResult
{ {
internal RawResult this[int index]
{
get
{
if (index >= _itemsCount) throw new IndexOutOfRangeException();
return _itemsOversized[index];
}
}
internal int ItemsCount => _itemsCount;
internal static readonly RawResult NullMultiBulk = new RawResult(null, 0); internal static readonly RawResult NullMultiBulk = new RawResult(null, 0);
internal static readonly RawResult EmptyMultiBulk = new RawResult(Array.Empty<RawResult>(), 0); internal static readonly RawResult EmptyMultiBulk = new RawResult(Array.Empty<RawResult>(), 0);
internal static readonly RawResult Nil = default; internal static readonly RawResult Nil = default;
...@@ -72,6 +81,58 @@ public override string ToString() ...@@ -72,6 +81,58 @@ public override string ToString()
} }
} }
public Tokenizer GetInlineTokenizer() => new Tokenizer(_payload);
internal ref struct Tokenizer
{
// tokenizes things according to the inline protocol
// specifically; the line: abc "def ghi" jkl
// is 3 tokens: "abc", "def ghi" and "jkl"
public Tokenizer GetEnumerator() => this;
BufferReader _value;
public Tokenizer(ReadOnlySequence<byte> value)
{
_value = new BufferReader(value);
Current = default;
}
public bool MoveNext()
{
Current = default;
// take any white-space
while (_value.PeekByte() == (byte)' ') { _value.Consume(1); }
byte terminator = (byte)' ';
var first = _value.PeekByte();
if (first < 0) return false; // EOF
switch (_value.PeekByte())
{
case (byte)'"':
case (byte)'\'':
// start of string
terminator = (byte)first;
_value.Consume(1);
break;
}
int end = BufferReader.FindNext(_value, terminator);
if (end < 0)
{
Current = _value.ConsumeToEnd();
}
else
{
Current = _value.ConsumeAsBuffer(end);
_value.Consume(1); // drop the terminator itself;
}
return true;
}
public ReadOnlySequence<byte> Current { get; private set; }
}
internal RedisChannel AsRedisChannel(byte[] channelPrefix, RedisChannel.PatternMode mode) internal RedisChannel AsRedisChannel(byte[] channelPrefix, RedisChannel.PatternMode mode)
{ {
switch (Type) switch (Type)
...@@ -202,6 +263,12 @@ internal ReadOnlySpan<RawResult> GetItems() ...@@ -202,6 +263,12 @@ internal ReadOnlySpan<RawResult> GetItems()
return new ReadOnlySpan<RawResult>(_itemsOversized, 0, _itemsCount); return new ReadOnlySpan<RawResult>(_itemsOversized, 0, _itemsCount);
throw new InvalidOperationException(); throw new InvalidOperationException();
} }
internal ReadOnlyMemory<RawResult> GetItemsMemory()
{
if (Type == ResultType.MultiBulk)
return new ReadOnlyMemory<RawResult>(_itemsOversized, 0, _itemsCount);
throw new InvalidOperationException();
}
internal RedisKey[] GetItemsAsKeys() internal RedisKey[] GetItemsAsKeys()
{ {
......
...@@ -698,13 +698,13 @@ protected override void WriteImpl(PhysicalConnection physical) ...@@ -698,13 +698,13 @@ protected override void WriteImpl(PhysicalConnection physical)
bool isCopy = (migrateOptions & MigrateOptions.Copy) != 0; bool isCopy = (migrateOptions & MigrateOptions.Copy) != 0;
bool isReplace = (migrateOptions & MigrateOptions.Replace) != 0; bool isReplace = (migrateOptions & MigrateOptions.Replace) != 0;
physical.WriteHeader(Command, 5 + (isCopy ? 1 : 0) + (isReplace ? 1 : 0)); physical.WriteHeader(Command, 5 + (isCopy ? 1 : 0) + (isReplace ? 1 : 0));
physical.Write(toHost); physical.WriteBulkString(toHost);
physical.Write(toPort); physical.WriteBulkString(toPort);
physical.Write(Key); physical.Write(Key);
physical.Write(toDatabase); physical.WriteBulkString(toDatabase);
physical.Write(timeoutMilliseconds); physical.WriteBulkString(timeoutMilliseconds);
if (isCopy) physical.Write(RedisLiterals.COPY); if (isCopy) physical.WriteBulkString(RedisLiterals.COPY);
if (isReplace) physical.Write(RedisLiterals.REPLACE); if (isReplace) physical.WriteBulkString(RedisLiterals.REPLACE);
} }
} }
...@@ -3169,8 +3169,8 @@ public ScriptLoadMessage(CommandFlags flags, string script) ...@@ -3169,8 +3169,8 @@ public ScriptLoadMessage(CommandFlags flags, string script)
protected override void WriteImpl(PhysicalConnection physical) protected override void WriteImpl(PhysicalConnection physical)
{ {
physical.WriteHeader(Command, 2); physical.WriteHeader(Command, 2);
physical.Write(RedisLiterals.LOAD); physical.WriteBulkString(RedisLiterals.LOAD);
physical.Write((RedisValue)Script); physical.WriteBulkString((RedisValue)Script);
} }
} }
...@@ -3237,7 +3237,7 @@ protected override void WriteImpl(PhysicalConnection physical) ...@@ -3237,7 +3237,7 @@ protected override void WriteImpl(PhysicalConnection physical)
{ // recognises well-known types { // recognises well-known types
var val = RedisValue.TryParse(arg); var val = RedisValue.TryParse(arg);
if (val.IsNull && arg != null) throw new InvalidCastException($"Unable to parse value: '{arg}'"); if (val.IsNull && arg != null) throw new InvalidCastException($"Unable to parse value: '{arg}'");
physical.Write(val); physical.WriteBulkString(val);
} }
} }
} }
...@@ -3334,18 +3334,18 @@ protected override void WriteImpl(PhysicalConnection physical) ...@@ -3334,18 +3334,18 @@ protected override void WriteImpl(PhysicalConnection physical)
else if (asciiHash != null) else if (asciiHash != null)
{ {
physical.WriteHeader(RedisCommand.EVALSHA, 2 + keys.Length + values.Length); physical.WriteHeader(RedisCommand.EVALSHA, 2 + keys.Length + values.Length);
physical.Write((RedisValue)asciiHash); physical.WriteBulkString((RedisValue)asciiHash);
} }
else else
{ {
physical.WriteHeader(RedisCommand.EVAL, 2 + keys.Length + values.Length); physical.WriteHeader(RedisCommand.EVAL, 2 + keys.Length + values.Length);
physical.Write((RedisValue)script); physical.WriteBulkString((RedisValue)script);
} }
physical.Write(keys.Length); physical.WriteBulkString(keys.Length);
for (int i = 0; i < keys.Length; i++) for (int i = 0; i < keys.Length; i++)
physical.Write(keys[i]); physical.Write(keys[i]);
for (int i = 0; i < values.Length; i++) for (int i = 0; i < values.Length; i++)
physical.Write(values[i]); physical.WriteBulkString(values[i]);
} }
} }
...@@ -3386,11 +3386,11 @@ protected override void WriteImpl(PhysicalConnection physical) ...@@ -3386,11 +3386,11 @@ protected override void WriteImpl(PhysicalConnection physical)
{ {
physical.WriteHeader(Command, 2 + keys.Length + values.Length); physical.WriteHeader(Command, 2 + keys.Length + values.Length);
physical.Write(Key); physical.Write(Key);
physical.Write(keys.Length); physical.WriteBulkString(keys.Length);
for (int i = 0; i < keys.Length; i++) for (int i = 0; i < keys.Length; i++)
physical.Write(keys[i]); physical.Write(keys[i]);
for (int i = 0; i < values.Length; i++) for (int i = 0; i < values.Length; i++)
physical.Write(values[i]); physical.WriteBulkString(values[i]);
} }
} }
......
...@@ -11,16 +11,37 @@ public abstract class RedisResult ...@@ -11,16 +11,37 @@ public abstract class RedisResult
/// Create a new RedisResult representing a single value. /// Create a new RedisResult representing a single value.
/// </summary> /// </summary>
/// <param name="value">The <see cref="RedisValue"/> to create a result from.</param> /// <param name="value">The <see cref="RedisValue"/> to create a result from.</param>
/// <param name="resultType">The type of result being represented</param>
/// <returns> new <see cref="RedisResult"/>.</returns> /// <returns> new <see cref="RedisResult"/>.</returns>
public static RedisResult Create(RedisValue value) => new SingleRedisResult(value, null); public static RedisResult Create(RedisValue value, ResultType? resultType = null) => new SingleRedisResult(value, resultType);
/// <summary> /// <summary>
/// Create a new RedisResult representing an array of values. /// Create a new RedisResult representing an array of values.
/// </summary> /// </summary>
/// <param name="values">The <see cref="RedisValue"/>s to create a result from.</param> /// <param name="values">The <see cref="RedisValue"/>s to create a result from.</param>
/// <returns> new <see cref="RedisResult"/>.</returns> /// <returns> new <see cref="RedisResult"/>.</returns>
public static RedisResult Create(RedisValue[] values) => new ArrayRedisResult( public static RedisResult Create(RedisValue[] values) =>
values == null ? null : Array.ConvertAll(values, value => new SingleRedisResult(value, null))); values == null ? NullArray : values.Length == 0 ? EmptyArray :
new ArrayRedisResult(Array.ConvertAll(values, value => new SingleRedisResult(value, null)));
/// <summary>
/// Create a new RedisResult representing an array of values.
/// </summary>
/// <param name="values">The <see cref="RedisResult"/>s to create a result from.</param>
/// <returns> new <see cref="RedisResult"/>.</returns>
public static RedisResult Create(RedisResult[] values)
=> values == null ? NullArray : values.Length == 0 ? EmptyArray : new ArrayRedisResult(values);
/// <summary>
/// An empty array result
/// </summary>
public static RedisResult EmptyArray { get; } = new ArrayRedisResult(Array.Empty<RedisResult>());
/// <summary>
/// A null array result
/// </summary>
public static RedisResult NullArray { get; } = new ArrayRedisResult(null);
// internally, this is very similar to RawResult, except it is designed to be usable // internally, this is very similar to RawResult, except it is designed to be usable
// outside of the IO-processing pipeline: the buffers are standalone, etc // outside of the IO-processing pipeline: the buffers are standalone, etc
...@@ -36,8 +57,10 @@ internal static RedisResult TryCreate(PhysicalConnection connection, RawResult r ...@@ -36,8 +57,10 @@ internal static RedisResult TryCreate(PhysicalConnection connection, RawResult r
case ResultType.BulkString: case ResultType.BulkString:
return new SingleRedisResult(result.AsRedisValue(), result.Type); return new SingleRedisResult(result.AsRedisValue(), result.Type);
case ResultType.MultiBulk: case ResultType.MultiBulk:
if (result.IsNull) return NullArray;
var items = result.GetItems(); var items = result.GetItems();
var arr = result.IsNull ? null : new RedisResult[items.Length]; if (items.Length == 0) return EmptyArray;
var arr = new RedisResult[items.Length];
for (int i = 0; i < arr.Length; i++) for (int i = 0; i < arr.Length; i++)
{ {
var next = TryCreate(connection, items[i]); var next = TryCreate(connection, items[i]);
...@@ -50,7 +73,8 @@ internal static RedisResult TryCreate(PhysicalConnection connection, RawResult r ...@@ -50,7 +73,8 @@ internal static RedisResult TryCreate(PhysicalConnection connection, RawResult r
default: default:
return null; return null;
} }
} catch (Exception ex) }
catch (Exception ex)
{ {
connection?.OnInternalError(ex); connection?.OnInternalError(ex);
return null; // will be logged as a protocol fail by the processor return null; // will be logged as a protocol fail by the processor
...@@ -66,6 +90,23 @@ internal static RedisResult TryCreate(PhysicalConnection connection, RawResult r ...@@ -66,6 +90,23 @@ internal static RedisResult TryCreate(PhysicalConnection connection, RawResult r
/// Indicates whether this result was a null result /// Indicates whether this result was a null result
/// </summary> /// </summary>
public abstract bool IsNull { get; } public abstract bool IsNull { get; }
/// <summary>
/// A successful result
/// </summary>
public static RedisResult OK { get; } = Create("OK", ResultType.SimpleString);
/// <summary>
/// An integer-zero result
/// </summary>
public static RedisResult Zero { get; } = Create(0, ResultType.Integer);
/// <summary>
/// An integer-one result
/// </summary>
public static RedisResult One { get; } = Create(1, ResultType.Integer);
/// <summary>
/// A null bulk-string result
/// </summary>
public static RedisResult Null { get; } = Create(RedisValue.Null, ResultType.BulkString);
/// <summary> /// <summary>
/// Interprets the result as a <see cref="string"/>. /// Interprets the result as a <see cref="string"/>.
...@@ -196,108 +237,140 @@ internal static RedisResult TryCreate(PhysicalConnection connection, RawResult r ...@@ -196,108 +237,140 @@ internal static RedisResult TryCreate(PhysicalConnection connection, RawResult r
internal abstract string[] AsStringArray(); internal abstract string[] AsStringArray();
private sealed class ArrayRedisResult : RedisResult private sealed class ArrayRedisResult : RedisResult
{ {
public override bool IsNull => value == null; public override bool IsNull => _value == null;
private readonly RedisResult[] value; private readonly RedisResult[] _value;
public override ResultType Type => ResultType.MultiBulk; public override ResultType Type => ResultType.MultiBulk;
public ArrayRedisResult(RedisResult[] value) public ArrayRedisResult(RedisResult[] value)
{ {
this.value = value ?? throw new ArgumentNullException(nameof(value)); _value = value;
} }
public override string ToString() => value.Length + " element(s)"; public override string ToString() => _value == null ? "(nil)" : (_value.Length + " element(s)");
internal override bool AsBoolean() internal override bool AsBoolean()
{ {
if (value.Length == 1) return value[0].AsBoolean(); if (IsSingleton) return _value[0].AsBoolean();
throw new InvalidCastException(); throw new InvalidCastException();
} }
internal override bool[] AsBooleanArray() => Array.ConvertAll(value, x => x.AsBoolean()); internal override bool[] AsBooleanArray() => IsNull ? null : Array.ConvertAll(_value, x => x.AsBoolean());
internal override byte[] AsByteArray() internal override byte[] AsByteArray()
{ {
if (value.Length == 1) return value[0].AsByteArray(); if (IsSingleton) return _value[0].AsByteArray();
throw new InvalidCastException(); throw new InvalidCastException();
} }
internal override byte[][] AsByteArrayArray() => Array.ConvertAll(value, x => x.AsByteArray()); internal override byte[][] AsByteArrayArray()
=> IsNull ? null
: _value.Length == 0 ? Array.Empty<byte[]>()
: Array.ConvertAll(_value, x => x.AsByteArray());
private bool IsSingleton => _value != null && _value.Length == 1;
private bool IsEmpty => _value != null && _value.Length == 0;
internal override double AsDouble() internal override double AsDouble()
{ {
if (value.Length == 1) return value[0].AsDouble(); if (IsSingleton) return _value[0].AsDouble();
throw new InvalidCastException(); throw new InvalidCastException();
} }
internal override double[] AsDoubleArray() => Array.ConvertAll(value, x => x.AsDouble()); internal override double[] AsDoubleArray()
=> IsNull ? null
: IsEmpty ? Array.Empty<double>()
: Array.ConvertAll(_value, x => x.AsDouble());
internal override int AsInt32() internal override int AsInt32()
{ {
if (value.Length == 1) return value[0].AsInt32(); if (IsSingleton) return _value[0].AsInt32();
throw new InvalidCastException(); throw new InvalidCastException();
} }
internal override int[] AsInt32Array() => Array.ConvertAll(value, x => x.AsInt32()); internal override int[] AsInt32Array()
=> IsNull ? null
: IsEmpty ? Array.Empty<int>()
: Array.ConvertAll(_value, x => x.AsInt32());
internal override long AsInt64() internal override long AsInt64()
{ {
if (value.Length == 1) return value[0].AsInt64(); if (IsSingleton) return _value[0].AsInt64();
throw new InvalidCastException(); throw new InvalidCastException();
} }
internal override long[] AsInt64Array() => Array.ConvertAll(value, x => x.AsInt64()); internal override long[] AsInt64Array()
=> IsNull ? null
: IsEmpty ? Array.Empty<long>()
: Array.ConvertAll(_value, x => x.AsInt64());
internal override bool? AsNullableBoolean() internal override bool? AsNullableBoolean()
{ {
if (value.Length == 1) return value[0].AsNullableBoolean(); if (IsSingleton) return _value[0].AsNullableBoolean();
throw new InvalidCastException(); throw new InvalidCastException();
} }
internal override double? AsNullableDouble() internal override double? AsNullableDouble()
{ {
if (value.Length == 1) return value[0].AsNullableDouble(); if (IsSingleton) return _value[0].AsNullableDouble();
throw new InvalidCastException(); throw new InvalidCastException();
} }
internal override int? AsNullableInt32() internal override int? AsNullableInt32()
{ {
if (value.Length == 1) return value[0].AsNullableInt32(); if (IsSingleton) return _value[0].AsNullableInt32();
throw new InvalidCastException(); throw new InvalidCastException();
} }
internal override long? AsNullableInt64() internal override long? AsNullableInt64()
{ {
if (value.Length == 1) return value[0].AsNullableInt64(); if (IsSingleton) return _value[0].AsNullableInt64();
throw new InvalidCastException(); throw new InvalidCastException();
} }
internal override RedisKey AsRedisKey() internal override RedisKey AsRedisKey()
{ {
if (value.Length == 1) return value[0].AsRedisKey(); if (IsSingleton) return _value[0].AsRedisKey();
throw new InvalidCastException(); throw new InvalidCastException();
} }
internal override RedisKey[] AsRedisKeyArray() => Array.ConvertAll(value, x => x.AsRedisKey()); internal override RedisKey[] AsRedisKeyArray()
=> IsNull ? null
: IsEmpty ? Array.Empty<RedisKey>()
: Array.ConvertAll(_value, x => x.AsRedisKey());
internal override RedisResult[] AsRedisResultArray() => value; internal override RedisResult[] AsRedisResultArray() => _value;
internal override RedisValue AsRedisValue() internal override RedisValue AsRedisValue()
{ {
if (value.Length == 1) return value[0].AsRedisValue(); if (IsSingleton) return _value[0].AsRedisValue();
throw new InvalidCastException(); throw new InvalidCastException();
} }
internal override RedisValue[] AsRedisValueArray() => Array.ConvertAll(value, x => x.AsRedisValue()); internal override RedisValue[] AsRedisValueArray()
=> IsNull ? null
: IsEmpty ? Array.Empty<RedisValue>()
: Array.ConvertAll(_value, x => x.AsRedisValue());
internal override string AsString() internal override string AsString()
{ {
if (value.Length == 1) return value[0].AsString(); if (IsSingleton) return _value[0].AsString();
throw new InvalidCastException(); throw new InvalidCastException();
} }
internal override string[] AsStringArray() => Array.ConvertAll(value, x => x.AsString()); internal override string[] AsStringArray()
=> IsNull ? null
: IsEmpty ? Array.Empty<string>()
: Array.ConvertAll(_value, x => x.AsString());
} }
/// <summary>
/// Create a RedisResult from a key
/// </summary>
public static RedisResult Create(RedisKey key) => Create(key.AsRedisValue(), ResultType.BulkString);
/// <summary>
/// Create a RedisResult from a channel
/// </summary>
public static RedisResult Create(RedisChannel channel) => Create((byte[])channel, ResultType.BulkString);
private sealed class ErrorRedisResult : RedisResult private sealed class ErrorRedisResult : RedisResult
{ {
private readonly string value; private readonly string value;
......
...@@ -26,8 +26,32 @@ private RedisValue(long overlappedValue64, ReadOnlyMemory<byte> memory, object o ...@@ -26,8 +26,32 @@ private RedisValue(long overlappedValue64, ReadOnlyMemory<byte> memory, object o
} }
private readonly static object Sentinel_Integer = new object(); private readonly static object Sentinel_Integer = new object();
private readonly static object Sentinel_Raw = new object(); private readonly static object Sentinel_Raw = new object();
private readonly static object Sentinel_Double = new object(); private readonly static object Sentinel_Double = new object();
/// <summary>
/// Obtain this value as an object - to be used alongside Unbox
/// </summary>
public object Box()
{
var obj = _objectOrSentinel;
if (obj is null || obj is string || obj is byte[]) return obj;
return this;
}
/// <summary>
/// Parse this object as a value - to be used alongside Box
/// </summary>
public static RedisValue Unbox(object value)
{
if (value == null) return RedisValue.Null;
if (value is string s) return s;
if (value is byte[] b) return b;
return (RedisValue)value;
}
/// <summary> /// <summary>
/// Represents the string <c>""</c> /// Represents the string <c>""</c>
/// </summary> /// </summary>
...@@ -305,6 +329,20 @@ internal StorageType Type ...@@ -305,6 +329,20 @@ internal StorageType Type
} }
} }
/// <summary>
/// Get the size of this value in bytes
/// </summary>
public long Length()
{
switch(Type)
{
case StorageType.Null: return 0;
case StorageType.Raw: return _memory.Length;
case StorageType.String: return Encoding.UTF8.GetByteCount((string)_objectOrSentinel);
default: throw new InvalidOperationException("Unable to compute length of type: " + Type);
}
}
/// <summary> /// <summary>
/// Compare against a RedisValue for relative order /// Compare against a RedisValue for relative order
/// </summary> /// </summary>
......
...@@ -364,7 +364,7 @@ protected override void WriteImpl(PhysicalConnection physical) ...@@ -364,7 +364,7 @@ protected override void WriteImpl(PhysicalConnection physical)
else else
{ {
physical.WriteHeader(command, 1); physical.WriteHeader(command, 1);
physical.Write(value); physical.WriteBulkString(value);
} }
} }
} }
......
using System; using System;
using System.Diagnostics; using System.Net;
using System.Threading.Tasks;
using StackExchange.Redis.Server;
namespace TestConsole static class Program
{ {
internal static class Program static async Task Main()
{ {
private static int Main() using (var server = new MemoryCacheRedisServer(Console.Out))
{ {
try server.Listen(new IPEndPoint(IPAddress.Loopback, 6378));
{ await server.Shutdown;
using (var obj = new BasicTest.RedisBenchmarks())
{
var watch = Stopwatch.StartNew();
obj.ExecuteIncrBy();
watch.Stop();
Console.WriteLine($"{watch.ElapsedMilliseconds}ms");
}
return 0;
}
catch (Exception ex)
{
Console.Error.WriteLine(ex.Message);
return -1;
}
} }
} }
} }
...@@ -12,6 +12,8 @@ ...@@ -12,6 +12,8 @@
<ItemGroup> <ItemGroup>
<ProjectReference Include="..\BasicTest\BasicTest.csproj" /> <ProjectReference Include="..\BasicTest\BasicTest.csproj" />
<ProjectReference Include="..\StackExchange.Redis.Server\StackExchange.Redis.Server.csproj" />
<ProjectReference Include="..\StackExchange.Redis.Tests\StackExchange.Redis.Tests.csproj" />
<ProjectReference Include="..\StackExchange.Redis\StackExchange.Redis.csproj" /> <ProjectReference Include="..\StackExchange.Redis\StackExchange.Redis.csproj" />
</ItemGroup> </ItemGroup>
</Project> </Project>
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment