Commit e931d6a7 authored by Marc Gravell's avatar Marc Gravell

Per-server script cache (aka EVALSHA)

parent 0130f2ee
using System; using System;
using System.Diagnostics;
using NUnit.Framework; using NUnit.Framework;
namespace StackExchange.Redis.Tests namespace StackExchange.Redis.Tests
...@@ -27,5 +28,104 @@ public void TestBasicScripting() ...@@ -27,5 +28,104 @@ public void TestBasicScripting()
Assert.IsFalse(wasSet); Assert.IsFalse(wasSet);
} }
} }
[Test]
public void CheckLoads()
{
using (var conn0 = Create(allowAdmin: true))
using (var conn1 = Create(allowAdmin: true))
{
// note that these are on different connections (so we wouldn't expect
// the flush to drop the local cache - assume it is a surprise!)
var server = conn0.GetServer(PrimaryServer, PrimaryPort);
var db = conn1.GetDatabase();
const string script = "return 1;";
// start empty
server.ScriptFlush();
Assert.IsFalse(server.ScriptExists(script));
// run once, causes to be cached
Assert.IsTrue((bool)db.ScriptEvaluate(script));
Assert.IsTrue(server.ScriptExists(script));
// can run again
Assert.IsTrue((bool)db.ScriptEvaluate(script));
// ditch the scripts; should no longer exist
db.Ping();
server.ScriptFlush();
Assert.IsFalse(server.ScriptExists(script));
db.Ping();
// now: fails the first time
try
{
Assert.IsTrue((bool)db.ScriptEvaluate(script));
Assert.Fail();
} catch(RedisServerException ex)
{
Assert.IsTrue(ex.Message == "NOSCRIPT No matching script. Please use EVAL.");
}
// but gets marked as unloaded, so we can use it again...
Assert.IsTrue((bool)db.ScriptEvaluate(script));
// which will cause it to be cached
Assert.IsTrue(server.ScriptExists(script));
}
}
[Test]
public void CompareScriptToDirect()
{
const string Script = "return redis.call('incr', KEYS[1])";
using (var conn = Create(allowAdmin: true))
{
var server = conn.GetServer(PrimaryServer, PrimaryPort);
server.FlushAllDatabases();
server.ScriptFlush();
server.ScriptLoad(Script);
var db = conn.GetDatabase();
db.Ping(); // k, we're all up to date now; clean db, minimal script cache
// we're using a pipeline here, so send 1000 messages, but for timing: only care about the last
const int LOOP = 5000;
RedisKey key = "foo";
RedisKey[] keys = new[] { key }; // script takes an array
// run via script
db.KeyDelete(key);
CollectGarbage();
var watch = Stopwatch.StartNew();
for(int i = 1; i < LOOP; i++) // the i=1 is to do all-but-one
{
db.ScriptEvaluate(Script, keys, flags: CommandFlags.FireAndForget);
}
var scriptResult = db.ScriptEvaluate(Script, keys); // last one we wait for (no F+F)
watch.Stop();
TimeSpan scriptTime = watch.Elapsed;
// run via raw op
db.KeyDelete(key);
CollectGarbage();
watch = Stopwatch.StartNew();
for (int i = 1; i < LOOP; i++) // the i=1 is to do all-but-one
{
db.StringIncrement(key, flags: CommandFlags.FireAndForget);
}
var directResult = db.StringIncrement(key); // last one we wait for (no F+F)
watch.Stop();
TimeSpan directTime = watch.Elapsed;
Assert.AreEqual(LOOP, (long)scriptResult, "script result");
Assert.AreEqual(LOOP, (long)directResult, "direct result");
Console.WriteLine("script: {0}ms; direct: {1}ms",
scriptTime.TotalMilliseconds,
directTime.TotalMilliseconds);
}
}
} }
} }
...@@ -16,6 +16,14 @@ namespace StackExchange.Redis.Tests ...@@ -16,6 +16,14 @@ namespace StackExchange.Redis.Tests
public abstract class TestBase : IDisposable public abstract class TestBase : IDisposable
{ {
protected void CollectGarbage()
{
for (int i = 0; i < 3; i++)
{
GC.Collect(GC.MaxGeneration, GCCollectionMode.Forced);
GC.WaitForPendingFinalizers();
}
}
private readonly SocketManager socketManager; private readonly SocketManager socketManager;
protected TestBase() protected TestBase()
...@@ -138,6 +146,12 @@ protected IServer GetServer(ConnectionMultiplexer muxer) ...@@ -138,6 +146,12 @@ protected IServer GetServer(ConnectionMultiplexer muxer)
map[cmd] = null; map[cmd] = null;
config.CommandMap = CommandMap.Create(map); config.CommandMap = CommandMap.Create(map);
} }
if(Debugger.IsAttached)
{
syncTimeout = int.MaxValue;
}
if (useSharedSocketManager) config.SocketManager = socketManager; if (useSharedSocketManager) config.SocketManager = socketManager;
if (channelPrefix != null) config.ChannelPrefix = channelPrefix; if (channelPrefix != null) config.ChannelPrefix = channelPrefix;
if (tieBreaker != null) config.TieBreaker = tieBreaker; if (tieBreaker != null) config.TieBreaker = tieBreaker;
...@@ -149,7 +163,7 @@ protected IServer GetServer(ConnectionMultiplexer muxer) ...@@ -149,7 +163,7 @@ protected IServer GetServer(ConnectionMultiplexer muxer)
if (connectTimeout != null) config.ConnectTimeout = connectTimeout.Value; if (connectTimeout != null) config.ConnectTimeout = connectTimeout.Value;
var watch = Stopwatch.StartNew(); var watch = Stopwatch.StartNew();
var task = ConnectionMultiplexer.ConnectAsync(config, log ?? Console.Out); var task = ConnectionMultiplexer.ConnectAsync(config, log ?? Console.Out);
if (!task.Wait(config.ConnectTimeout * 2)) if (!task.Wait(config.ConnectTimeout >= (int.MaxValue / 2) ? int.MaxValue : config.ConnectTimeout * 2))
{ {
task.ContinueWith(x => task.ContinueWith(x =>
{ {
......
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
using System.Text.RegularExpressions; using System.Text.RegularExpressions;
using System.Threading.Tasks; using System.Threading.Tasks;
using System.Runtime.CompilerServices; using System.Runtime.CompilerServices;
using System.Collections.Generic;
namespace StackExchange.Redis namespace StackExchange.Redis
{ {
......
...@@ -253,6 +253,46 @@ public interface IServer : IRedis ...@@ -253,6 +253,46 @@ public interface IServer : IRedis
/// <remarks>http://redis.io/topics/persistence</remarks> /// <remarks>http://redis.io/topics/persistence</remarks>
Task SaveAsync(SaveType type, CommandFlags flags = CommandFlags.None); Task SaveAsync(SaveType type, CommandFlags flags = CommandFlags.None);
/// <summary>
/// Inidicates whether the specified script is defined on the server
/// </summary>
bool ScriptExists(string script, CommandFlags flags = CommandFlags.None);
/// <summary>
/// Inidicates whether the specified script hash is defined on the server
/// </summary>
bool ScriptExists(byte[] sha1, CommandFlags flags = CommandFlags.None);
/// <summary>
/// Inidicates whether the specified script is defined on the server
/// </summary>
Task<bool> ScriptExistsAsync(string script, CommandFlags flags = CommandFlags.None);
/// <summary>
/// Inidicates whether the specified script hash is defined on the server
/// </summary>
Task<bool> ScriptExistsAsync(byte[] sha1, CommandFlags flags = CommandFlags.None);
/// <summary>
/// Removes all cached scripts on this server
/// </summary>
void ScriptFlush(CommandFlags flags = CommandFlags.None);
/// <summary>
/// Removes all cached scripts on this server
/// </summary>
Task ScriptFlushAsync(CommandFlags flags = CommandFlags.None);
/// <summary>
/// Explicitly defines a script on the server
/// </summary>
byte[] ScriptLoad(string script, CommandFlags flags = CommandFlags.None);
/// <summary>
/// Explicitly defines a script on the server
/// </summary>
Task<byte[]> ScriptLoadAsync(string script, CommandFlags flags = CommandFlags.None);
/// <summary>Asks the redis server to shutdown, killing all connections. Please FULLY read the notes on the SHUTDOWN command.</summary> /// <summary>Asks the redis server to shutdown, killing all connections. Please FULLY read the notes on the SHUTDOWN command.</summary>
/// <remarks>http://redis.io/commands/shutdown</remarks> /// <remarks>http://redis.io/commands/shutdown</remarks>
void Shutdown(ShutdownMode shutdownMode = ShutdownMode.Default, CommandFlags flags = CommandFlags.None); void Shutdown(ShutdownMode shutdownMode = ShutdownMode.Default, CommandFlags flags = CommandFlags.None);
......
...@@ -469,6 +469,7 @@ internal static bool RequiresDatabase(RedisCommand command) ...@@ -469,6 +469,7 @@ internal static bool RequiresDatabase(RedisCommand command)
case RedisCommand.READONLY: case RedisCommand.READONLY:
case RedisCommand.READWRITE: case RedisCommand.READWRITE:
case RedisCommand.SAVE: case RedisCommand.SAVE:
case RedisCommand.SCRIPT:
case RedisCommand.SHUTDOWN: case RedisCommand.SHUTDOWN:
case RedisCommand.SLAVEOF: case RedisCommand.SLAVEOF:
case RedisCommand.SLOWLOG: case RedisCommand.SLOWLOG:
......
...@@ -280,7 +280,7 @@ internal void OnConnectionFailed(PhysicalConnection connection, ConnectionFailur ...@@ -280,7 +280,7 @@ internal void OnConnectionFailed(PhysicalConnection connection, ConnectionFailur
} }
} }
internal void OnDisconnected(ConnectionFailureType failureType, PhysicalConnection connection, out bool isCurrent) internal void OnDisconnected(ConnectionFailureType failureType, PhysicalConnection connection, out bool isCurrent, out State oldState)
{ {
Trace("OnDisconnected"); Trace("OnDisconnected");
...@@ -294,11 +294,11 @@ internal void OnDisconnected(ConnectionFailureType failureType, PhysicalConnecti ...@@ -294,11 +294,11 @@ internal void OnDisconnected(ConnectionFailureType failureType, PhysicalConnecti
ping.Fail(failureType, null); ping.Fail(failureType, null);
CompleteSyncOrAsync(ping); CompleteSyncOrAsync(ping);
} }
oldState = default(State); // only defined when isCurrent = true
if (isCurrent = (physical == connection)) if (isCurrent = (physical == connection))
{ {
Trace("Bridge noting disconnect from active connection" + (isDisposed ? " (disposed)" : "")); Trace("Bridge noting disconnect from active connection" + (isDisposed ? " (disposed)" : ""));
ChangeState(State.Disconnected); oldState = ChangeState(State.Disconnected);
physical = null; physical = null;
if (!isDisposed && Interlocked.Increment(ref failConnectCount) == 1) if (!isDisposed && Interlocked.Increment(ref failConnectCount) == 1)
...@@ -367,7 +367,8 @@ internal void OnHeartbeat() ...@@ -367,7 +367,8 @@ internal void OnHeartbeat()
else else
{ {
bool ignore; bool ignore;
OnDisconnected(ConnectionFailureType.SocketFailure, tmp, out ignore); State oldState;
OnDisconnected(ConnectionFailureType.SocketFailure, tmp, out ignore, out oldState);
} }
} }
} }
...@@ -463,6 +464,7 @@ internal void WriteMessageDirect(PhysicalConnection tmp, Message next) ...@@ -463,6 +464,7 @@ internal void WriteMessageDirect(PhysicalConnection tmp, Message next)
{ {
// we screwed up; abort; note that WriteMessageToServer already // we screwed up; abort; note that WriteMessageToServer already
// killed the underlying connection // killed the underlying connection
Trace("Unable to write to server");
next.Fail(ConnectionFailureType.ProtocolFailure, null); next.Fail(ConnectionFailureType.ProtocolFailure, null);
CompleteSyncOrAsync(next); CompleteSyncOrAsync(next);
break; break;
...@@ -475,13 +477,14 @@ internal void WriteMessageDirect(PhysicalConnection tmp, Message next) ...@@ -475,13 +477,14 @@ internal void WriteMessageDirect(PhysicalConnection tmp, Message next)
} }
} }
private void ChangeState(State newState) private State ChangeState(State newState)
{ {
var oldState = (State)Interlocked.Exchange(ref state, (int)newState); var oldState = (State)Interlocked.Exchange(ref state, (int)newState);
if (oldState != newState) if (oldState != newState)
{ {
multiplexer.Trace(connectionType + " state changed from " + oldState + " to " + newState); multiplexer.Trace(connectionType + " state changed from " + oldState + " to " + newState);
} }
return oldState;
} }
private bool ChangeState(State oldState, State newState) private bool ChangeState(State oldState, State newState)
......
...@@ -71,7 +71,8 @@ private static readonly Message ...@@ -71,7 +71,8 @@ private static readonly Message
public PhysicalConnection(PhysicalBridge bridge) public PhysicalConnection(PhysicalBridge bridge)
{ {
lastWriteTickCount = lastReadTickCount = lastBeatTickCount = Environment.TickCount; lastWriteTickCount = lastReadTickCount = Environment.TickCount;
lastBeatTickCount = 0;
this.connectionType = bridge.ConnectionType; this.connectionType = bridge.ConnectionType;
this.multiplexer = bridge.Multiplexer; this.multiplexer = bridge.Multiplexer;
this.ChannelPrefix = multiplexer.RawConfig.ChannelPrefix; this.ChannelPrefix = multiplexer.RawConfig.ChannelPrefix;
...@@ -151,7 +152,8 @@ public void RecordConnectionFailed(ConnectionFailureType failureType, Exception ...@@ -151,7 +152,8 @@ public void RecordConnectionFailed(ConnectionFailureType failureType, Exception
// stop anything new coming in... // stop anything new coming in...
bridge.Trace("Failed: " + failureType); bridge.Trace("Failed: " + failureType);
bool isCurrent; bool isCurrent;
bridge.OnDisconnected(failureType, this, out isCurrent); PhysicalBridge.State oldState;
bridge.OnDisconnected(failureType, this, out isCurrent, out oldState);
if (isCurrent && Interlocked.CompareExchange(ref failureReported, 1, 0) == 0) if (isCurrent && Interlocked.CompareExchange(ref failureReported, 1, 0) == 0)
{ {
...@@ -163,7 +165,7 @@ public void RecordConnectionFailed(ConnectionFailureType failureType, Exception ...@@ -163,7 +165,7 @@ public void RecordConnectionFailed(ConnectionFailureType failureType, Exception
string message = failureType + " on " + Format.ToString(bridge.ServerEndPoint.EndPoint) + "/" + connectionType string message = failureType + " on " + Format.ToString(bridge.ServerEndPoint.EndPoint) + "/" + connectionType
+ ", input-buffer: " + ioBufferBytes + ", outstanding: " + GetOutstandingCount() + ", input-buffer: " + ioBufferBytes + ", outstanding: " + GetOutstandingCount()
+ ", last-read: " + unchecked(now - lastRead) / 1000 + "s ago, last-write: " + unchecked(now - lastWrite) / 1000 + "s ago, keep-alive: " + bridge.ServerEndPoint.WriteEverySeconds + "s, pending: " + ", last-read: " + unchecked(now - lastRead) / 1000 + "s ago, last-write: " + unchecked(now - lastWrite) / 1000 + "s ago, keep-alive: " + bridge.ServerEndPoint.WriteEverySeconds + "s, pending: "
+ bridge.GetPendingCount() + ", last-heartbeat: " + unchecked(now - lastBeat) / 1000 + "s ago"; + bridge.GetPendingCount() + ", state: " + oldState + ", last-heartbeat: " + (lastBeat == 0 ? "never" : (unchecked(now - lastBeat) / 1000 + "s ago"));
var ex = innerException == null var ex = innerException == null
? new RedisConnectionException(failureType, message) ? new RedisConnectionException(failureType, message)
......
...@@ -108,6 +108,7 @@ enum RedisCommand ...@@ -108,6 +108,7 @@ enum RedisCommand
SAVE, SAVE,
SCAN, SCAN,
SCARD, SCARD,
SCRIPT,
SDIFF, SDIFF,
SDIFFSTORE, SDIFFSTORE,
SELECT, SELECT,
......
...@@ -676,13 +676,13 @@ public Task<RedisKey> RandomKeyAsync(CommandFlags flags = CommandFlags.None) ...@@ -676,13 +676,13 @@ public Task<RedisKey> RandomKeyAsync(CommandFlags flags = CommandFlags.None)
public RedisResult ScriptEvaluate(string script, RedisKey[] keys = null, RedisValue[] values = null, CommandFlags flags = CommandFlags.None) public RedisResult ScriptEvaluate(string script, RedisKey[] keys = null, RedisValue[] values = null, CommandFlags flags = CommandFlags.None)
{ {
var msg = new ScriptEvalMessage(Db, flags, RedisCommand.EVAL, script, keys ?? RedisKey.EmptyArray, values ?? RedisValue.EmptyArray); var msg = new ScriptEvalMessage(Db, flags, RedisCommand.EVAL, script, keys ?? RedisKey.EmptyArray, values ?? RedisValue.EmptyArray);
return ExecuteSync(msg, ResultProcessor.RedisResult); return ExecuteSync(msg, ResultProcessor.ScriptResult);
} }
public Task<RedisResult> ScriptEvaluateAsync(string script, RedisKey[] keys = null, RedisValue[] values = null, CommandFlags flags = CommandFlags.None) public Task<RedisResult> ScriptEvaluateAsync(string script, RedisKey[] keys = null, RedisValue[] values = null, CommandFlags flags = CommandFlags.None)
{ {
var msg = new ScriptEvalMessage(Db, flags, RedisCommand.EVAL, script, keys ?? RedisKey.EmptyArray, values ?? RedisValue.EmptyArray); var msg = new ScriptEvalMessage(Db, flags, RedisCommand.EVAL, script, keys ?? RedisKey.EmptyArray, values ?? RedisValue.EmptyArray);
return ExecuteAsync(msg, ResultProcessor.RedisResult); return ExecuteAsync(msg, ResultProcessor.ScriptResult);
} }
public bool SetAdd(RedisKey key, RedisValue value, CommandFlags flags = CommandFlags.None) public bool SetAdd(RedisKey key, RedisValue value, CommandFlags flags = CommandFlags.None)
...@@ -1797,6 +1797,22 @@ protected override bool SetResultCore(PhysicalConnection connection, Message mes ...@@ -1797,6 +1797,22 @@ protected override bool SetResultCore(PhysicalConnection connection, Message mes
} }
} }
internal sealed class ScriptLoadMessage : Message
{
internal readonly string Script;
public ScriptLoadMessage(CommandFlags flags, string script) : base(-1, flags, RedisCommand.SCRIPT)
{
if (script == null) throw new ArgumentNullException("script");
this.Script = script;
}
internal override void WriteImpl(PhysicalConnection physical)
{
physical.WriteHeader(Command, 2);
physical.Write(RedisLiterals.LOAD);
physical.Write((RedisValue)Script);
}
}
internal sealed class SetScanIterator internal sealed class SetScanIterator
{ {
internal const int DefaultPageSize = 10; internal const int DefaultPageSize = 10;
...@@ -1880,15 +1896,15 @@ Message CreateMessage(long cursor, bool running) ...@@ -1880,15 +1896,15 @@ Message CreateMessage(long cursor, bool running)
} }
} }
} }
private sealed class ScriptEvalMessage : Message, IMultiMessage
private sealed class ScriptEvalMessage : Message
{ {
private readonly RedisKey[] keys; private readonly RedisKey[] keys;
private readonly RedisValue script; private readonly string script;
private readonly RedisValue[] values; private readonly RedisValue[] values;
private RedisValue hash;
public ScriptEvalMessage(int db, CommandFlags flags, RedisCommand command, string script, RedisKey[] keys, RedisValue[] values) : base(db, flags, command) public ScriptEvalMessage(int db, CommandFlags flags, RedisCommand command, string script, RedisKey[] keys, RedisValue[] values) : base(db, flags, command)
{ {
if (script == null) throw new ArgumentNullException("script");
this.script = script; this.script = script;
for (int i = 0; i < keys.Length; i++) for (int i = 0; i < keys.Length; i++)
keys[i].Assert(); keys[i].Assert();
...@@ -1905,10 +1921,31 @@ public override int GetHashSlot(ServerSelectionStrategy serverSelectionStrategy) ...@@ -1905,10 +1921,31 @@ public override int GetHashSlot(ServerSelectionStrategy serverSelectionStrategy)
return slot; return slot;
} }
public IEnumerable<Message> GetMessages(PhysicalConnection connection)
{
this.hash = connection.Bridge.ServerEndPoint.GetScriptHash(script);
if(hash.IsNull)
{
var msg = new ScriptLoadMessage(Flags, script);
msg.SetInternalCall();
msg.SetSource(ResultProcessor.ScriptLoad, null);
yield return msg;
}
yield return this;
}
internal override void WriteImpl(PhysicalConnection physical) internal override void WriteImpl(PhysicalConnection physical)
{ {
physical.WriteHeader(command, 2 + keys.Length + values.Length); if(hash.IsNull)
physical.Write(script); {
physical.WriteHeader(RedisCommand.EVAL, 2 + keys.Length + values.Length);
physical.Write((RedisValue)script);
}
else
{
physical.WriteHeader(RedisCommand.EVALSHA, 2 + keys.Length + values.Length);
physical.Write(hash);
}
physical.Write(keys.Length); physical.Write(keys.Length);
for (int i = 0; i < keys.Length; i++) for (int i = 0; i < keys.Length; i++)
physical.Write(keys[i]); physical.Write(keys[i]);
......
...@@ -48,6 +48,9 @@ public static readonly RedisValue ...@@ -48,6 +48,9 @@ public static readonly RedisValue
MIN = "MIN", MIN = "MIN",
MAX = "MAX", MAX = "MAX",
AGGREGATE = "AGGREGATE", AGGREGATE = "AGGREGATE",
LOAD = "LOAD",
EXISTS = "EXISTS",
FLUSH = "FLUSH",
// DO NOT CHANGE CASE: these are configuration settings and MUST be as-is // DO NOT CHANGE CASE: these are configuration settings and MUST be as-is
databases = "databases", databases = "databases",
......
...@@ -3,6 +3,8 @@ ...@@ -3,6 +3,8 @@
using System.IO; using System.IO;
using System.Linq; using System.Linq;
using System.Net; using System.Net;
using System.Security.Cryptography;
using System.Text;
using System.Threading.Tasks; using System.Threading.Tasks;
namespace StackExchange.Redis namespace StackExchange.Redis
...@@ -138,6 +140,18 @@ public Task ConfigSetAsync(RedisValue setting, RedisValue value, CommandFlags fl ...@@ -138,6 +140,18 @@ public Task ConfigSetAsync(RedisValue setting, RedisValue value, CommandFlags fl
ExecuteSync(Message.Create(-1, flags | CommandFlags.FireAndForget, RedisCommand.CONFIG, RedisLiterals.GET, setting), ResultProcessor.AutoConfigure); ExecuteSync(Message.Create(-1, flags | CommandFlags.FireAndForget, RedisCommand.CONFIG, RedisLiterals.GET, setting), ResultProcessor.AutoConfigure);
return task; return task;
} }
public long DatabaseSize(int database = 0, CommandFlags flags = CommandFlags.None)
{
var msg = Message.Create(database, flags, RedisCommand.DBSIZE);
return ExecuteSync(msg, ResultProcessor.Int64);
}
public Task<long> DatabaseSizeAsync(int database = 0, CommandFlags flags = CommandFlags.None)
{
var msg = Message.Create(database, flags, RedisCommand.DBSIZE);
return ExecuteAsync(msg, ResultProcessor.Int64);
}
public void FlushAllDatabases(CommandFlags flags = CommandFlags.None) public void FlushAllDatabases(CommandFlags flags = CommandFlags.None)
{ {
var msg = Message.Create(-1, flags, RedisCommand.FLUSHALL); var msg = Message.Create(-1, flags, RedisCommand.FLUSHALL);
...@@ -149,6 +163,18 @@ public Task FlushAllDatabasesAsync(CommandFlags flags = CommandFlags.None) ...@@ -149,6 +163,18 @@ public Task FlushAllDatabasesAsync(CommandFlags flags = CommandFlags.None)
return ExecuteAsync(msg, ResultProcessor.DemandOK); return ExecuteAsync(msg, ResultProcessor.DemandOK);
} }
public void FlushDatabase(int database = 0, CommandFlags flags = CommandFlags.None)
{
var msg = Message.Create(database, flags, RedisCommand.FLUSHDB);
ExecuteSync(msg, ResultProcessor.DemandOK);
}
public Task FlushDatabaseAsync(int database = 0, CommandFlags flags = CommandFlags.None)
{
var msg = Message.Create(database, flags, RedisCommand.FLUSHDB);
return ExecuteAsync(msg, ResultProcessor.DemandOK);
}
public ServerCounters GetCounters() public ServerCounters GetCounters()
{ {
return server.GetCounters(); return server.GetCounters();
...@@ -222,15 +248,6 @@ public void MakeMaster(ReplicationChangeOptions options, TextWriter log = null) ...@@ -222,15 +248,6 @@ public void MakeMaster(ReplicationChangeOptions options, TextWriter log = null)
{ {
multiplexer.MakeMaster(server, options, log); multiplexer.MakeMaster(server, options, log);
} }
Message GetSaveMessage(SaveType type, CommandFlags flags = CommandFlags.None)
{
switch(type)
{
case SaveType.BackgroundRewriteAppendOnlyFile: return Message.Create(-1, flags, RedisCommand.BGREWRITEAOF);
case SaveType.BackgroundSave: return Message.Create(-1, flags, RedisCommand.BGSAVE);
default: throw new ArgumentOutOfRangeException("type");
}
}
public void Save(SaveType type, CommandFlags flags = CommandFlags.None) public void Save(SaveType type, CommandFlags flags = CommandFlags.None)
{ {
var msg = GetSaveMessage(type, flags); var msg = GetSaveMessage(type, flags);
...@@ -243,6 +260,56 @@ public Task SaveAsync(SaveType type, CommandFlags flags = CommandFlags.None) ...@@ -243,6 +260,56 @@ public Task SaveAsync(SaveType type, CommandFlags flags = CommandFlags.None)
return ExecuteAsync(msg, ResultProcessor.DemandOK); return ExecuteAsync(msg, ResultProcessor.DemandOK);
} }
public bool ScriptExists(string script, CommandFlags flags = CommandFlags.None)
{
var msg = Message.Create(-1, flags, RedisCommand.SCRIPT, RedisLiterals.EXISTS, ScriptHash.Hash(script));
return ExecuteSync(msg, ResultProcessor.Boolean);
}
public bool ScriptExists(byte[] sha1, CommandFlags flags = CommandFlags.None)
{
var msg = Message.Create(-1, flags, RedisCommand.SCRIPT, RedisLiterals.EXISTS, ScriptHash.Encode(sha1));
return ExecuteSync(msg, ResultProcessor.Boolean);
}
public Task<bool> ScriptExistsAsync(string script, CommandFlags flags = CommandFlags.None)
{
var msg = Message.Create(-1, flags, RedisCommand.SCRIPT, RedisLiterals.EXISTS, ScriptHash.Hash(script));
return ExecuteAsync(msg, ResultProcessor.Boolean);
}
public Task<bool> ScriptExistsAsync(byte[] sha1, CommandFlags flags = CommandFlags.None)
{
var msg = Message.Create(-1, flags, RedisCommand.SCRIPT, RedisLiterals.EXISTS, ScriptHash.Encode(sha1));
return ExecuteAsync(msg, ResultProcessor.Boolean);
}
public void ScriptFlush(CommandFlags flags = CommandFlags.None)
{
if (!multiplexer.RawConfig.AllowAdmin) throw ExceptionFactory.AdminModeNotEnabled(RedisCommand.SCRIPT);
var msg = Message.Create(-1, flags, RedisCommand.SCRIPT, RedisLiterals.FLUSH);
ExecuteSync(msg, ResultProcessor.DemandOK);
}
public Task ScriptFlushAsync(CommandFlags flags = CommandFlags.None)
{
if (!multiplexer.RawConfig.AllowAdmin) throw ExceptionFactory.AdminModeNotEnabled(RedisCommand.SCRIPT);
var msg = Message.Create(-1, flags, RedisCommand.SCRIPT, RedisLiterals.FLUSH);
return ExecuteAsync(msg, ResultProcessor.DemandOK);
}
public byte[] ScriptLoad(string script, CommandFlags flags = CommandFlags.None)
{
var msg = new RedisDatabase.ScriptLoadMessage(flags, script);
return ExecuteSync(msg, ResultProcessor.ScriptLoad);
}
public Task<byte[]> ScriptLoadAsync(string script, CommandFlags flags = CommandFlags.None)
{
var msg = new RedisDatabase.ScriptLoadMessage(flags, script);
return ExecuteAsync(msg, ResultProcessor.ScriptLoad);
}
public void Shutdown(ShutdownMode shutdownMode = ShutdownMode.Default, CommandFlags flags = CommandFlags.None) public void Shutdown(ShutdownMode shutdownMode = ShutdownMode.Default, CommandFlags flags = CommandFlags.None)
{ {
Message msg; Message msg;
...@@ -333,28 +400,27 @@ public Task<RedisChannel[]> SubscriptionChannelsAsync(RedisChannel pattern = def ...@@ -333,28 +400,27 @@ public Task<RedisChannel[]> SubscriptionChannelsAsync(RedisChannel pattern = def
return ExecuteAsync(msg, ResultProcessor.RedisChannelArray); return ExecuteAsync(msg, ResultProcessor.RedisChannelArray);
} }
public long SubscriptionSubscriberCount(RedisChannel channel, CommandFlags flags = CommandFlags.None) public long SubscriptionPatternCount(CommandFlags flags = CommandFlags.None)
{ {
var msg = Message.Create(-1, flags, RedisCommand.PUBSUB, RedisLiterals.NUMSUB, channel); var msg = Message.Create(-1, flags, RedisCommand.PUBLISH, RedisLiterals.NUMPAT);
return ExecuteSync(msg, ResultProcessor.Int64); return ExecuteSync(msg, ResultProcessor.Int64);
} }
public Task<long> SubscriptionSubscriberCountAsync(RedisChannel channel, CommandFlags flags = CommandFlags.None) public Task<long> SubscriptionPatternCountAsync(CommandFlags flags = CommandFlags.None)
{ {
var msg = Message.Create(-1, flags, RedisCommand.PUBSUB, RedisLiterals.NUMSUB, channel); var msg = Message.Create(-1, flags, RedisCommand.PUBLISH, RedisLiterals.NUMPAT);
return ExecuteAsync(msg, ResultProcessor.Int64); return ExecuteAsync(msg, ResultProcessor.Int64);
} }
public long SubscriptionSubscriberCount(RedisChannel channel, CommandFlags flags = CommandFlags.None)
public long SubscriptionPatternCount(CommandFlags flags = CommandFlags.None)
{ {
var msg = Message.Create(-1, flags, RedisCommand.PUBLISH, RedisLiterals.NUMPAT); var msg = Message.Create(-1, flags, RedisCommand.PUBSUB, RedisLiterals.NUMSUB, channel);
return ExecuteSync(msg, ResultProcessor.Int64); return ExecuteSync(msg, ResultProcessor.Int64);
} }
public Task<long> SubscriptionPatternCountAsync(CommandFlags flags = CommandFlags.None) public Task<long> SubscriptionSubscriberCountAsync(RedisChannel channel, CommandFlags flags = CommandFlags.None)
{ {
var msg = Message.Create(-1, flags, RedisCommand.PUBLISH, RedisLiterals.NUMPAT); var msg = Message.Create(-1, flags, RedisCommand.PUBSUB, RedisLiterals.NUMSUB, channel);
return ExecuteAsync(msg, ResultProcessor.Int64); return ExecuteAsync(msg, ResultProcessor.Int64);
} }
...@@ -382,7 +448,7 @@ internal static Message CreateSlaveOfMessage(EndPoint endpoint, CommandFlags fla ...@@ -382,7 +448,7 @@ internal static Message CreateSlaveOfMessage(EndPoint endpoint, CommandFlags fla
{ {
string hostRaw; string hostRaw;
int portRaw; int portRaw;
if(Format.TryGetHostPort(endpoint, out hostRaw, out portRaw)) if (Format.TryGetHostPort(endpoint, out hostRaw, out portRaw))
{ {
host = hostRaw; host = hostRaw;
port = portRaw; port = portRaw;
...@@ -394,6 +460,7 @@ internal static Message CreateSlaveOfMessage(EndPoint endpoint, CommandFlags fla ...@@ -394,6 +460,7 @@ internal static Message CreateSlaveOfMessage(EndPoint endpoint, CommandFlags fla
} }
return Message.Create(-1, flags, RedisCommand.SLAVEOF, host, port); return Message.Create(-1, flags, RedisCommand.SLAVEOF, host, port);
} }
internal override Task<T> ExecuteAsync<T>(Message message, ResultProcessor<T> processor, ServerEndPoint server = null) internal override Task<T> ExecuteAsync<T>(Message message, ResultProcessor<T> processor, ServerEndPoint server = null)
{ // inject our expected server automatically { // inject our expected server automatically
if (server == null) server = this.server; if (server == null) server = this.server;
...@@ -422,22 +489,6 @@ internal override T ExecuteSync<T>(Message message, ResultProcessor<T> processor ...@@ -422,22 +489,6 @@ internal override T ExecuteSync<T>(Message message, ResultProcessor<T> processor
return base.ExecuteSync<T>(message, processor, server); return base.ExecuteSync<T>(message, processor, server);
} }
private void FixFlags(Message message, ServerEndPoint server)
{
// since the server is specified explicitly, we don't want defaults
// to make the "non-preferred-endpoint" counters look artificially
// inflated; note we only change *prefer* options
switch(Message.GetMasterSlaveFlags(message.Flags))
{
case CommandFlags.PreferMaster:
if (server.IsSlave) message.SetPreferSlave();
break;
case CommandFlags.PreferSlave:
if (!server.IsSlave) message.SetPreferMaster();
break;
}
}
internal override RedisFeatures GetFeatures(int db, RedisKey key, CommandFlags flags, out ServerEndPoint server) internal override RedisFeatures GetFeatures(int db, RedisKey key, CommandFlags flags, out ServerEndPoint server)
{ {
server = this.server; server = this.server;
...@@ -454,31 +505,31 @@ internal void SlaveOf(EndPoint endpoint, CommandFlags flags = CommandFlags.None) ...@@ -454,31 +505,31 @@ internal void SlaveOf(EndPoint endpoint, CommandFlags flags = CommandFlags.None)
ExecuteSync(msg, ResultProcessor.DemandOK); ExecuteSync(msg, ResultProcessor.DemandOK);
} }
private void FixFlags(Message message, ServerEndPoint server)
public long DatabaseSize(int database = 0, CommandFlags flags = CommandFlags.None)
{
var msg = Message.Create(database, flags, RedisCommand.DBSIZE);
return ExecuteSync(msg, ResultProcessor.Int64);
}
public Task<long> DatabaseSizeAsync(int database = 0, CommandFlags flags = CommandFlags.None)
{
var msg = Message.Create(database, flags, RedisCommand.DBSIZE);
return ExecuteAsync(msg, ResultProcessor.Int64);
}
public void FlushDatabase(int database = 0, CommandFlags flags = CommandFlags.None)
{ {
var msg = Message.Create(database, flags, RedisCommand.FLUSHDB); // since the server is specified explicitly, we don't want defaults
ExecuteSync(msg, ResultProcessor.DemandOK); // to make the "non-preferred-endpoint" counters look artificially
// inflated; note we only change *prefer* options
switch (Message.GetMasterSlaveFlags(message.Flags))
{
case CommandFlags.PreferMaster:
if (server.IsSlave) message.SetPreferSlave();
break;
case CommandFlags.PreferSlave:
if (!server.IsSlave) message.SetPreferMaster();
break;
}
} }
public Task FlushDatabaseAsync(int database = 0, CommandFlags flags = CommandFlags.None) Message GetSaveMessage(SaveType type, CommandFlags flags = CommandFlags.None)
{ {
var msg = Message.Create(database, flags, RedisCommand.FLUSHDB); switch(type)
return ExecuteAsync(msg, ResultProcessor.DemandOK); {
case SaveType.BackgroundRewriteAppendOnlyFile: return Message.Create(-1, flags, RedisCommand.BGREWRITEAOF);
case SaveType.BackgroundSave: return Message.Create(-1, flags, RedisCommand.BGSAVE);
default: throw new ArgumentOutOfRangeException("type");
}
} }
struct KeysScanResult struct KeysScanResult
{ {
public static readonly ResultProcessor<KeysScanResult> Processor = new KeysResultProcessor(); public static readonly ResultProcessor<KeysScanResult> Processor = new KeysResultProcessor();
...@@ -511,6 +562,34 @@ protected override bool SetResultCore(PhysicalConnection connection, Message mes ...@@ -511,6 +562,34 @@ protected override bool SetResultCore(PhysicalConnection connection, Message mes
} }
} }
static class ScriptHash
{
static readonly byte[] hex = {
(byte)'0', (byte)'1', (byte)'2', (byte)'3', (byte)'4', (byte)'5', (byte)'6', (byte)'7',
(byte)'8', (byte)'9', (byte)'a', (byte)'b', (byte)'c', (byte)'d', (byte)'e', (byte)'f' };
public static RedisValue Encode(byte[] value)
{
if (value == null) return default(RedisValue);
byte[] result = new byte[value.Length * 2];
int offset = 0;
for (int i = 0; i < value.Length; i++)
{
int val = value[i];
result[offset++] = hex[val / 16];
result[offset++] = hex[val % 16];
}
return result;
}
public static RedisValue Hash(string value)
{
if (value == null) return default(RedisValue);
using (var sha1 = SHA1.Create())
{
var bytes = sha1.ComputeHash(Encoding.UTF8.GetBytes(value));
return Encode(bytes);
}
}
}
sealed class KeysScanIterator sealed class KeysScanIterator
{ {
internal const int DefaultPageSize = 10; internal const int DefaultPageSize = 10;
......
...@@ -26,7 +26,8 @@ abstract class ResultProcessor ...@@ -26,7 +26,8 @@ abstract class ResultProcessor
NullableDouble = new NullableDoubleProcessor(); NullableDouble = new NullableDoubleProcessor();
public static readonly ResultProcessor<byte[]> public static readonly ResultProcessor<byte[]>
ByteArray = new ByteArrayProcessor(); ByteArray = new ByteArrayProcessor(),
ScriptLoad = new ScriptLoadProcessor();
public static readonly ResultProcessor<ClusterConfiguration> public static readonly ResultProcessor<ClusterConfiguration>
ClusterNodes = new ClusterNodesProcessor(); ClusterNodes = new ClusterNodesProcessor();
...@@ -81,7 +82,7 @@ public static readonly TimeSpanProcessor ...@@ -81,7 +82,7 @@ public static readonly TimeSpanProcessor
SortedSetWithScores = new SortedSetWithScoresProcessor(); SortedSetWithScores = new SortedSetWithScoresProcessor();
public static readonly ResultProcessor<RedisResult> public static readonly ResultProcessor<RedisResult>
RedisResult = new RedisResultProcessor(); ScriptResult = new ScriptResultProcessor();
static readonly byte[] MOVED = Encoding.UTF8.GetBytes("MOVED "), ASK = Encoding.UTF8.GetBytes("ASK "); static readonly byte[] MOVED = Encoding.UTF8.GetBytes("MOVED "), ASK = Encoding.UTF8.GetBytes("ASK ");
...@@ -505,6 +506,14 @@ protected override bool SetResultCore(PhysicalConnection connection, Message mes ...@@ -505,6 +506,14 @@ protected override bool SetResultCore(PhysicalConnection connection, Message mes
case ResultType.BulkString: case ResultType.BulkString:
SetResult(message, result.GetBoolean()); SetResult(message, result.GetBoolean());
return true; return true;
case ResultType.Array:
var items = result.GetItems();
if(items.Length == 1)
{ // treat an array of 1 like a single reply (for example, SCRIPT EXISTS)
SetResult(message, items[0].GetBoolean());
return true;
}
break;
} }
return false; return false;
} }
...@@ -524,6 +533,28 @@ protected override bool SetResultCore(PhysicalConnection connection, Message mes ...@@ -524,6 +533,28 @@ protected override bool SetResultCore(PhysicalConnection connection, Message mes
} }
} }
internal sealed class ScriptLoadProcessor : ResultProcessor<byte[]>
{
// note that top-level error messages still get handled by SetResult, but nested errors
// (is that a thing?) will be wrapped in the RedisResult
protected override bool SetResultCore(PhysicalConnection connection, Message message, RawResult result)
{
switch (result.Type)
{
case ResultType.BulkString:
var hash = result.GetBlob();
var sl = message as RedisDatabase.ScriptLoadMessage;
if (sl != null)
{
connection.Bridge.ServerEndPoint.AddScript(sl.Script, hash);
}
SetResult(message, hash);
return true;
}
return false;
}
}
sealed class ClusterNodesProcessor : ResultProcessor<ClusterConfiguration> sealed class ClusterNodesProcessor : ResultProcessor<ClusterConfiguration>
{ {
internal static ClusterConfiguration Parse(PhysicalConnection connection, string nodes) internal static ClusterConfiguration Parse(PhysicalConnection connection, string nodes)
...@@ -1012,8 +1043,19 @@ protected override bool SetResultCore(PhysicalConnection connection, Message mes ...@@ -1012,8 +1043,19 @@ protected override bool SetResultCore(PhysicalConnection connection, Message mes
} }
} }
private class RedisResultProcessor : ResultProcessor<RedisResult> private class ScriptResultProcessor : ResultProcessor<RedisResult>
{ {
static readonly byte[] NOSCRIPT = Encoding.UTF8.GetBytes("NOSCRIPT ");
public override bool SetResult(PhysicalConnection connection, Message message, RawResult result)
{
if(result.Type == ResultType.Error && result.AssertStarts(NOSCRIPT))
{ // scripts are not flushed individually, so assume the entire script cache is toast ("SCRIPT FLUSH")
connection.Bridge.ServerEndPoint.FlushScripts();
}
// and apply usual processing for the rest
return base.SetResult(connection, message, result);
}
// note that top-level error messages still get handled by SetResult, but nested errors // note that top-level error messages still get handled by SetResult, but nested errors
// (is that a thing?) will be wrapped in the RedisResult // (is that a thing?) will be wrapped in the RedisResult
protected override bool SetResultCore(PhysicalConnection connection, Message message, RawResult result) protected override bool SetResultCore(PhysicalConnection connection, Message message, RawResult result)
......
using System; using System;
using System.Collections;
using System.Collections.Generic; using System.Collections.Generic;
using System.Linq; using System.Linq;
using System.Net; using System.Net;
...@@ -525,5 +526,25 @@ internal string GetProfile() ...@@ -525,5 +526,25 @@ internal string GetProfile()
if (tmp != null) tmp.AppendProfile(sb); if (tmp != null) tmp.AppendProfile(sb);
return sb.ToString(); return sb.ToString();
} }
private readonly Hashtable knownScripts = new Hashtable(StringComparer.Ordinal);
internal byte[] GetScriptHash(string script)
{
return (byte[])knownScripts[script];
}
internal void AddScript(string script, byte[] hash)
{
lock(knownScripts)
{
knownScripts[script] = hash;
}
}
internal void FlushScripts()
{
lock(knownScripts)
{
knownScripts.Clear();
}
}
} }
} }
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment