Commit ad868210 authored by kevin-montrose's avatar kevin-montrose

squash named-script-parameters branch into a single commit

parent dcf546f4
...@@ -15,4 +15,5 @@ Mono/ ...@@ -15,4 +15,5 @@ Mono/
*.orig *.orig
redis-cli.exe redis-cli.exe
Redis Configs/*.dat Redis Configs/*.dat
RedisQFork*.dat RedisQFork*.dat
\ No newline at end of file StackExchange.Redis.*.zip
\ No newline at end of file
using System; using System;
using System.Diagnostics; using System.Diagnostics;
using NUnit.Framework; using NUnit.Framework;
using System.Linq; using System.Linq;
using System.IO;
namespace StackExchange.Redis.Tests namespace StackExchange.Redis.Tests
{ {
...@@ -169,5 +170,345 @@ public void TestCallByHash() ...@@ -169,5 +170,345 @@ public void TestCallByHash()
} }
} }
[Test]
public void SimpleLuaScript()
{
const string Script = "return @ident";
using(var conn = Create(allowAdmin: true))
{
var server = conn.GetServer(PrimaryServer, PrimaryPort);
server.FlushAllDatabases();
server.ScriptFlush();
var prepared = LuaScript.Prepare(Script);
var db = conn.GetDatabase();
{
var val = prepared.Evaluate(db, new { ident = "hello" });
Assert.AreEqual("hello", (string)val);
}
{
var val = prepared.Evaluate(db, new { ident = 123 });
Assert.AreEqual(123, (int)val);
}
{
var val = prepared.Evaluate(db, new { ident = 123L });
Assert.AreEqual(123L, (long)val);
}
{
var val = prepared.Evaluate(db, new { ident = 1.1 });
Assert.AreEqual(1.1, (double)val);
}
{
var val = prepared.Evaluate(db, new { ident = true });
Assert.AreEqual(true, (bool)val);
}
{
var val = prepared.Evaluate(db, new { ident = new byte[] { 4, 5, 6 } });
Assert.IsTrue(new byte [] { 4, 5, 6}.SequenceEqual((byte[])val));
}
}
}
[Test]
public void LuaScriptWithKeys()
{
const string Script = "redis.call('set', @key, @value)";
using (var conn = Create(allowAdmin: true))
{
var server = conn.GetServer(PrimaryServer, PrimaryPort);
server.FlushAllDatabases();
server.ScriptFlush();
var script = LuaScript.Prepare(Script);
var db = conn.GetDatabase();
var p = new { key = (RedisKey)"testkey", value = 123 };
script.Evaluate(db, p);
var val = db.StringGet("testkey");
Assert.AreEqual(123, (int)val);
// no super clean way to extract this; so just abuse InternalsVisibleTo
RedisKey[] keys;
RedisValue[] args;
script.ExtractParameters(p, null, out keys, out args);
Assert.IsNotNull(keys);
Assert.AreEqual(1, keys.Length);
Assert.AreEqual("testkey", (string)keys[0]);
}
}
[Test]
public void NoInlineReplacement()
{
const string Script = "redis.call('set', @key, 'hello@example')";
using (var conn = Create(allowAdmin: true))
{
var server = conn.GetServer(PrimaryServer, PrimaryPort);
server.FlushAllDatabases();
server.ScriptFlush();
var script = LuaScript.Prepare(Script);
Assert.AreEqual("redis.call('set', ARGV[1], 'hello@example')", script.ExecutableScript);
var db = conn.GetDatabase();
var p = new { key = (RedisKey)"key" };
script.Evaluate(db, p);
var val = db.StringGet("key");
Assert.AreEqual("hello@example", (string)val);
}
}
[Test]
public void EscapeReplacement()
{
const string Script = "redis.call('set', @key, @@escapeMe)";
var script = LuaScript.Prepare(Script);
Assert.AreEqual("redis.call('set', ARGV[1], @escapeMe)", script.ExecutableScript);
}
[Test]
public void SimpleLoadedLuaScript()
{
const string Script = "return @ident";
using (var conn = Create(allowAdmin: true))
{
var server = conn.GetServer(PrimaryServer, PrimaryPort);
server.FlushAllDatabases();
server.ScriptFlush();
var prepared = LuaScript.Prepare(Script);
var loaded = prepared.Load(server);
var db = conn.GetDatabase();
{
var val = loaded.Evaluate(db, new { ident = "hello" });
Assert.AreEqual("hello", (string)val);
}
{
var val = loaded.Evaluate(db, new { ident = 123 });
Assert.AreEqual(123, (int)val);
}
{
var val = loaded.Evaluate(db, new { ident = 123L });
Assert.AreEqual(123L, (long)val);
}
{
var val = loaded.Evaluate(db, new { ident = 1.1 });
Assert.AreEqual(1.1, (double)val);
}
{
var val = loaded.Evaluate(db, new { ident = true });
Assert.AreEqual(true, (bool)val);
}
{
var val = loaded.Evaluate(db, new { ident = new byte[] { 4, 5, 6 } });
Assert.IsTrue(new byte[] { 4, 5, 6 }.SequenceEqual((byte[])val));
}
}
}
[Test]
public void LoadedLuaScriptWithKeys()
{
const string Script = "redis.call('set', @key, @value)";
using (var conn = Create(allowAdmin: true))
{
var server = conn.GetServer(PrimaryServer, PrimaryPort);
server.FlushAllDatabases();
server.ScriptFlush();
var script = LuaScript.Prepare(Script);
var prepared = script.Load(server);
var db = conn.GetDatabase();
var p = new { key = (RedisKey)"testkey", value = 123 };
prepared.Evaluate(db, p);
var val = db.StringGet("testkey");
Assert.AreEqual(123, (int)val);
// no super clean way to extract this; so just abuse InternalsVisibleTo
RedisKey[] keys;
RedisValue[] args;
prepared.Original.ExtractParameters(p, null, out keys, out args);
Assert.IsNotNull(keys);
Assert.AreEqual(1, keys.Length);
Assert.AreEqual("testkey", (string)keys[0]);
}
}
[Test]
public void PurgeLuaScriptCache()
{
const string Script = "redis.call('set', @PurgeLuaScriptCacheKey, @PurgeLuaScriptCacheValue)";
var first = LuaScript.Prepare(Script);
var fromCache = LuaScript.Prepare(Script);
Assert.IsTrue(object.ReferenceEquals(first, fromCache));
LuaScript.PurgeCache();
var shouldBeNew = LuaScript.Prepare(Script);
Assert.IsFalse(object.ReferenceEquals(first, shouldBeNew));
}
static void _PurgeLuaScriptOnFinalize(string script)
{
var first = LuaScript.Prepare(script);
var fromCache = LuaScript.Prepare(script);
Assert.IsTrue(object.ReferenceEquals(first, fromCache));
Assert.AreEqual(1, LuaScript.GetCachedScriptCount());
}
[Test]
public void PurgeLuaScriptOnFinalize()
{
const string Script = "redis.call('set', @PurgeLuaScriptOnFinalizeKey, @PurgeLuaScriptOnFinalizeValue)";
LuaScript.PurgeCache();
Assert.AreEqual(0, LuaScript.GetCachedScriptCount());
// This has to be a separate method to guarantee that the created LuaScript objects go out of scope,
// and are thus available to be GC'd
_PurgeLuaScriptOnFinalize(Script);
GC.Collect(2, GCCollectionMode.Forced, blocking: true);
GC.WaitForPendingFinalizers();
Assert.AreEqual(0, LuaScript.GetCachedScriptCount());
var shouldBeNew = LuaScript.Prepare(Script);
Assert.AreEqual(1, LuaScript.GetCachedScriptCount());
}
[Test]
public void IDatabaseLuaScriptConvenienceMethods()
{
const string Script = "redis.call('set', @key, @value)";
using (var conn = Create(allowAdmin: true))
{
var script = LuaScript.Prepare(Script);
var db = conn.GetDatabase();
db.ScriptEvaluate(script, new { key = (RedisKey)"key", value = "value" });
var val = db.StringGet("key");
Assert.AreEqual("value", (string)val);
var prepared = script.Load(conn.GetServer(conn.GetEndPoints()[0]));
db.ScriptEvaluate(prepared, new { key = (RedisKey)"key2", value = "value2" });
var val2 = db.StringGet("key2");
Assert.AreEqual("value2", (string)val2);
}
}
[Test]
public void IServerLuaScriptConvenienceMethods()
{
const string Script = "redis.call('set', @key, @value)";
using (var conn = Create(allowAdmin: true))
{
var script = LuaScript.Prepare(Script);
var server = conn.GetServer(conn.GetEndPoints()[0]);
var db = conn.GetDatabase();
var prepared = server.ScriptLoad(script);
db.ScriptEvaluate(prepared, new { key = (RedisKey)"key3", value = "value3" });
var val = db.StringGet("key3");
Assert.AreEqual("value3", (string)val);
}
}
[Test]
public void LuaScriptPrefixedKeys()
{
const string Script = "redis.call('set', @key, @value)";
var prepared = LuaScript.Prepare(Script);
var p = new { key = (RedisKey)"key", value = "hello" };
// no super clean way to extract this; so just abuse InternalsVisibleTo
RedisKey[] keys;
RedisValue[] args;
prepared.ExtractParameters(p, "prefix-", out keys, out args);
Assert.IsNotNull(keys);
Assert.AreEqual(1, keys.Length);
Assert.AreEqual("prefix-key", (string)keys[0]);
Assert.AreEqual(2, args.Length);
Assert.AreEqual("prefix-key", (string)args[0]);
Assert.AreEqual("hello", (string)args[1]);
}
[Test]
public void LuaScriptWithWrappedDatabase()
{
const string Script = "redis.call('set', @key, @value)";
using (var conn = Create(allowAdmin: true))
{
var db = conn.GetDatabase(0);
var wrappedDb = StackExchange.Redis.KeyspaceIsolation.DatabaseExtensions.WithKeyPrefix(db, "prefix-");
var prepared = LuaScript.Prepare(Script);
wrappedDb.ScriptEvaluate(prepared, new { key = (RedisKey)"mykey", value = 123 });
var val1 = wrappedDb.StringGet("mykey");
Assert.AreEqual(123, (int)val1);
var val2 = db.StringGet("prefix-mykey");
Assert.AreEqual(123, (int)val2);
var val3 = db.StringGet("mykey");
Assert.IsTrue(val3.IsNull);
}
}
[Test]
public void LoadedLuaScriptWithWrappedDatabase()
{
const string Script = "redis.call('set', @key, @value)";
using (var conn = Create(allowAdmin: true))
{
var db = conn.GetDatabase(0);
var wrappedDb = StackExchange.Redis.KeyspaceIsolation.DatabaseExtensions.WithKeyPrefix(db, "prefix2-");
var server = conn.GetServer(conn.GetEndPoints()[0]);
var prepared = LuaScript.Prepare(Script).Load(server);
wrappedDb.ScriptEvaluate(prepared, new { key = (RedisKey)"mykey", value = 123 });
var val1 = wrappedDb.StringGet("mykey");
Assert.AreEqual(123, (int)val1);
var val2 = db.StringGet("prefix2-mykey");
Assert.AreEqual(123, (int)val2);
var val3 = db.StringGet("mykey");
Assert.IsTrue(val3.IsNull);
}
}
} }
} }
...@@ -74,6 +74,7 @@ ...@@ -74,6 +74,7 @@
<Compile Include="StackExchange\Redis\HashEntry.cs" /> <Compile Include="StackExchange\Redis\HashEntry.cs" />
<Compile Include="StackExchange\Redis\InternalErrorEventArgs.cs" /> <Compile Include="StackExchange\Redis\InternalErrorEventArgs.cs" />
<Compile Include="StackExchange\Redis\MigrateOptions.cs" /> <Compile Include="StackExchange\Redis\MigrateOptions.cs" />
<Compile Include="StackExchange\Redis\LuaScript.cs" />
<Compile Include="StackExchange\Redis\RedisChannel.cs" /> <Compile Include="StackExchange\Redis\RedisChannel.cs" />
<Compile Include="StackExchange\Redis\Bitwise.cs" /> <Compile Include="StackExchange\Redis\Bitwise.cs" />
<Compile Include="StackExchange\Redis\ClientFlags.cs" /> <Compile Include="StackExchange\Redis\ClientFlags.cs" />
...@@ -139,6 +140,7 @@ ...@@ -139,6 +140,7 @@
<Compile Include="StackExchange\Redis\ResultProcessor.cs" /> <Compile Include="StackExchange\Redis\ResultProcessor.cs" />
<Compile Include="StackExchange\Redis\RedisSubscriber.cs" /> <Compile Include="StackExchange\Redis\RedisSubscriber.cs" />
<Compile Include="StackExchange\Redis\ResultType.cs" /> <Compile Include="StackExchange\Redis\ResultType.cs" />
<Compile Include="StackExchange\Redis\ScriptParameterMapper.cs" />
<Compile Include="StackExchange\Redis\ServerCounters.cs" /> <Compile Include="StackExchange\Redis\ServerCounters.cs" />
<Compile Include="StackExchange\Redis\ServerEndPoint.cs" /> <Compile Include="StackExchange\Redis\ServerEndPoint.cs" />
<Compile Include="StackExchange\Redis\ServerSelectionStrategy.cs" /> <Compile Include="StackExchange\Redis\ServerSelectionStrategy.cs" />
......
...@@ -68,6 +68,7 @@ ...@@ -68,6 +68,7 @@
<Compile Include="StackExchange\Redis\HashEntry.cs" /> <Compile Include="StackExchange\Redis\HashEntry.cs" />
<Compile Include="StackExchange\Redis\InternalErrorEventArgs.cs" /> <Compile Include="StackExchange\Redis\InternalErrorEventArgs.cs" />
<Compile Include="StackExchange\Redis\MigrateOptions.cs" /> <Compile Include="StackExchange\Redis\MigrateOptions.cs" />
<Compile Include="StackExchange\Redis\LuaScript.cs" />
<Compile Include="StackExchange\Redis\RedisChannel.cs" /> <Compile Include="StackExchange\Redis\RedisChannel.cs" />
<Compile Include="StackExchange\Redis\Bitwise.cs" /> <Compile Include="StackExchange\Redis\Bitwise.cs" />
<Compile Include="StackExchange\Redis\ClientFlags.cs" /> <Compile Include="StackExchange\Redis\ClientFlags.cs" />
...@@ -133,6 +134,7 @@ ...@@ -133,6 +134,7 @@
<Compile Include="StackExchange\Redis\ResultProcessor.cs" /> <Compile Include="StackExchange\Redis\ResultProcessor.cs" />
<Compile Include="StackExchange\Redis\RedisSubscriber.cs" /> <Compile Include="StackExchange\Redis\RedisSubscriber.cs" />
<Compile Include="StackExchange\Redis\ResultType.cs" /> <Compile Include="StackExchange\Redis\ResultType.cs" />
<Compile Include="StackExchange\Redis\ScriptParameterMapper.cs" />
<Compile Include="StackExchange\Redis\ServerCounters.cs" /> <Compile Include="StackExchange\Redis\ServerCounters.cs" />
<Compile Include="StackExchange\Redis\ServerEndPoint.cs" /> <Compile Include="StackExchange\Redis\ServerEndPoint.cs" />
<Compile Include="StackExchange\Redis\ServerSelectionStrategy.cs" /> <Compile Include="StackExchange\Redis\ServerSelectionStrategy.cs" />
......
...@@ -51,6 +51,30 @@ public static ConfiguredTaskAwaitable<T> ForAwait<T>(this Task<T> task) ...@@ -51,6 +51,30 @@ public static ConfiguredTaskAwaitable<T> ForAwait<T>(this Task<T> task)
/// </summary> /// </summary>
public sealed partial class ConnectionMultiplexer : IDisposable public sealed partial class ConnectionMultiplexer : IDisposable
{ {
private static TaskFactory _factory = null;
/// <summary>
/// Provides a way of overriding the default Task Factory. If not set, it will use the default Task.Factory.
/// Useful when top level code sets it's own factory which may interfere with Redis queries.
/// </summary>
public static TaskFactory Factory
{
get
{
if (_factory != null)
{
return _factory;
}
return Task.Factory;
}
set
{
_factory = value;
}
}
/// <summary> /// <summary>
/// Get summary statistics associates with this server /// Get summary statistics associates with this server
/// </summary> /// </summary>
...@@ -738,7 +762,7 @@ private static ConnectionMultiplexer ConnectImpl(Func<ConnectionMultiplexer> mul ...@@ -738,7 +762,7 @@ private static ConnectionMultiplexer ConnectImpl(Func<ConnectionMultiplexer> mul
killMe = muxer; killMe = muxer;
// note that task has timeouts internally, so it might take *just over* the regular timeout // note that task has timeouts internally, so it might take *just over* the regular timeout
// wrap into task to force async execution // wrap into task to force async execution
var task = Task.Factory.StartNew(() => { return muxer.ReconfigureAsync(true, false, log, null, "connect").Result; }); var task = Factory.StartNew(() => { return muxer.ReconfigureAsync(true, false, log, null, "connect").Result; });
if (!task.Wait(muxer.SyncConnectTimeout(true))) if (!task.Wait(muxer.SyncConnectTimeout(true)))
{ {
...@@ -1472,9 +1496,38 @@ private ServerEndPoint SelectServerByElection(ServerEndPoint[] servers, string e ...@@ -1472,9 +1496,38 @@ private ServerEndPoint SelectServerByElection(ServerEndPoint[] servers, string e
return servers[i]; return servers[i];
} }
LogLocked(log, "...but we couldn't find that"); LogLocked(log, "...but we couldn't find that");
var deDottedEndpoint = DeDotifyHost(endpoint);
for (int i = 0; i < servers.Length; i++)
{
if (string.Equals(DeDotifyHost(Format.ToString(servers[i].EndPoint)), deDottedEndpoint, StringComparison.OrdinalIgnoreCase))
{
LogLocked(log, "...but we did find instead: {0}", deDottedEndpoint);
return servers[i];
}
}
return null; return null;
} }
static string DeDotifyHost(string input)
{
if (string.IsNullOrWhiteSpace(input)) return input; // GIGO
if (!char.IsLetter(input[0])) return input; // need first char to be alpha for this to work
int periodPosition = input.IndexOf('.');
if (periodPosition <= 0) return input; // no period or starts with a period? nothing useful to split
int colonPosition = input.IndexOf(':');
if (colonPosition > 0)
{ // has a port specifier
return input.Substring(0, periodPosition) + input.Substring(colonPosition);
}
else
{
return input.Substring(0, periodPosition);
}
}
internal void UpdateClusterRange(ClusterConfiguration configuration) internal void UpdateClusterRange(ClusterConfiguration configuration)
{ {
if (configuration == null) return; if (configuration == null) return;
...@@ -1850,7 +1903,6 @@ public Task<long> PublishReconfigureAsync(CommandFlags flags = CommandFlags.None ...@@ -1850,7 +1903,6 @@ public Task<long> PublishReconfigureAsync(CommandFlags flags = CommandFlags.None
if (channel == null) return CompletedTask<long>.Default(null); if (channel == null) return CompletedTask<long>.Default(null);
return GetSubscriber().PublishAsync(channel, RedisLiterals.Wildcard, flags); return GetSubscriber().PublishAsync(channel, RedisLiterals.Wildcard, flags);
} }
} }
} }
...@@ -452,6 +452,7 @@ public interface IDatabase : IRedis, IDatabaseAsync ...@@ -452,6 +452,7 @@ public interface IDatabase : IRedis, IDatabaseAsync
/// <returns>the number of clients that received the message.</returns> /// <returns>the number of clients that received the message.</returns>
/// <remarks>http://redis.io/commands/publish</remarks> /// <remarks>http://redis.io/commands/publish</remarks>
long Publish(RedisChannel channel, RedisValue message, CommandFlags flags = CommandFlags.None); long Publish(RedisChannel channel, RedisValue message, CommandFlags flags = CommandFlags.None);
/// <summary> /// <summary>
/// Execute a Lua script against the server /// Execute a Lua script against the server
/// </summary> /// </summary>
...@@ -464,7 +465,20 @@ public interface IDatabase : IRedis, IDatabaseAsync ...@@ -464,7 +465,20 @@ public interface IDatabase : IRedis, IDatabaseAsync
/// </summary> /// </summary>
/// <remarks>http://redis.io/commands/evalsha</remarks> /// <remarks>http://redis.io/commands/evalsha</remarks>
/// <returns>A dynamic representation of the script's result</returns> /// <returns>A dynamic representation of the script's result</returns>
RedisResult ScriptEvaluate(byte[] hash, RedisKey[] keys = null, RedisValue[] values = null, CommandFlags flags = CommandFlags.None); RedisResult ScriptEvaluate(byte[] hash, RedisKey[] keys = null, RedisValue[] values = null, CommandFlags flags = CommandFlags.None);
/// <summary>
/// Execute a lua script against the server, using previously prepared script.
/// Named parameters, if any, are provided by the `parameters` object.
/// </summary>
RedisResult ScriptEvaluate(LuaScript script, object parameters = null, CommandFlags flags = CommandFlags.None);
/// <summary>
/// Execute a lua script against the server, using previously prepared and loaded script.
/// This method sends only the SHA1 hash of the lua script to Redis.
/// Named parameters, if any, are provided by the `parameters` object.
/// </summary>
RedisResult ScriptEvaluate(LoadedLuaScript script, object parameters = null, CommandFlags flags = CommandFlags.None);
/// <summary> /// <summary>
/// Add the specified member to the set stored at key. Specified members that are already a member of this set are ignored. If key does not exist, a new set is created before adding the specified members. /// Add the specified member to the set stored at key. Specified members that are already a member of this set are ignored. If key does not exist, a new set is created before adding the specified members.
......
...@@ -437,7 +437,20 @@ public interface IDatabaseAsync : IRedisAsync ...@@ -437,7 +437,20 @@ public interface IDatabaseAsync : IRedisAsync
/// </summary> /// </summary>
/// <remarks>http://redis.io/commands/evalsha</remarks> /// <remarks>http://redis.io/commands/evalsha</remarks>
/// <returns>A dynamic representation of the script's result</returns> /// <returns>A dynamic representation of the script's result</returns>
Task<RedisResult> ScriptEvaluateAsync(byte[] hash, RedisKey[] keys = null, RedisValue[] values = null, CommandFlags flags = CommandFlags.None); Task<RedisResult> ScriptEvaluateAsync(byte[] hash, RedisKey[] keys = null, RedisValue[] values = null, CommandFlags flags = CommandFlags.None);
/// <summary>
/// Execute a lua script against the server, using previously prepared script.
/// Named parameters, if any, are provided by the `parameters` object.
/// </summary>
Task<RedisResult> ScriptEvaluateAsync(LuaScript script, object parameters = null, CommandFlags flags = CommandFlags.None);
/// <summary>
/// Execute a lua script against the server, using previously prepared and loaded script.
/// This method sends only the SHA1 hash of the lua script to Redis.
/// Named parameters, if any, are provided by the `parameters` object.
/// </summary>
Task<RedisResult> ScriptEvaluateAsync(LoadedLuaScript script, object parameters = null, CommandFlags flags = CommandFlags.None);
/// <summary> /// <summary>
/// Add the specified member to the set stored at key. Specified members that are already a member of this set are ignored. If key does not exist, a new set is created before adding the specified members. /// Add the specified member to the set stored at key. Specified members that are already a member of this set are ignored. If key does not exist, a new set is created before adding the specified members.
......
...@@ -312,12 +312,22 @@ public partial interface IServer : IRedis ...@@ -312,12 +312,22 @@ public partial interface IServer : IRedis
/// <summary> /// <summary>
/// Explicitly defines a script on the server /// Explicitly defines a script on the server
/// </summary> /// </summary>
byte[] ScriptLoad(string script, CommandFlags flags = CommandFlags.None); byte[] ScriptLoad(string script, CommandFlags flags = CommandFlags.None);
/// <summary>
/// Explicitly defines a script on the server
/// </summary>
LoadedLuaScript ScriptLoad(LuaScript script, CommandFlags flags = CommandFlags.None);
/// <summary> /// <summary>
/// Explicitly defines a script on the server /// Explicitly defines a script on the server
/// </summary> /// </summary>
Task<byte[]> ScriptLoadAsync(string script, CommandFlags flags = CommandFlags.None); Task<byte[]> ScriptLoadAsync(string script, CommandFlags flags = CommandFlags.None);
/// <summary>
/// Explicitly defines a script on the server
/// </summary>
Task<LoadedLuaScript> ScriptLoadAsync(LuaScript script, CommandFlags flags = CommandFlags.None);
/// <summary>Asks the redis server to shutdown, killing all connections. Please FULLY read the notes on the SHUTDOWN command.</summary> /// <summary>Asks the redis server to shutdown, killing all connections. Please FULLY read the notes on the SHUTDOWN command.</summary>
/// <remarks>http://redis.io/commands/shutdown</remarks> /// <remarks>http://redis.io/commands/shutdown</remarks>
......
...@@ -323,6 +323,18 @@ public RedisResult ScriptEvaluate(string script, RedisKey[] keys = null, RedisVa ...@@ -323,6 +323,18 @@ public RedisResult ScriptEvaluate(string script, RedisKey[] keys = null, RedisVa
return this.Inner.ScriptEvaluate(script, this.ToInner(keys), values, flags); return this.Inner.ScriptEvaluate(script, this.ToInner(keys), values, flags);
} }
public RedisResult ScriptEvaluate(LuaScript script, object parameters = null, CommandFlags flags = CommandFlags.None)
{
// TODO: The return value could contain prefixed keys. It might make sense to 'unprefix' those?
return script.Evaluate(this.Inner, parameters, Prefix, flags);
}
public RedisResult ScriptEvaluate(LoadedLuaScript script, object parameters = null, CommandFlags flags = CommandFlags.None)
{
// TODO: The return value could contain prefixed keys. It might make sense to 'unprefix' those?
return script.Evaluate(this.Inner, parameters, Prefix, flags);
}
public long SetAdd(RedisKey key, RedisValue[] values, CommandFlags flags = CommandFlags.None) public long SetAdd(RedisKey key, RedisValue[] values, CommandFlags flags = CommandFlags.None)
{ {
return this.Inner.SetAdd(this.ToInner(key), values, flags); return this.Inner.SetAdd(this.ToInner(key), values, flags);
......
...@@ -334,6 +334,16 @@ public Task<RedisResult> ScriptEvaluateAsync(string script, RedisKey[] keys = nu ...@@ -334,6 +334,16 @@ public Task<RedisResult> ScriptEvaluateAsync(string script, RedisKey[] keys = nu
return this.Inner.ScriptEvaluateAsync(script, this.ToInner(keys), values, flags); return this.Inner.ScriptEvaluateAsync(script, this.ToInner(keys), values, flags);
} }
public Task<RedisResult> ScriptEvaluateAsync(LuaScript script, object parameters = null, CommandFlags flags = CommandFlags.None)
{
throw new NotImplementedException();
}
public Task<RedisResult> ScriptEvaluateAsync(LoadedLuaScript script, object parameters = null, CommandFlags flags = CommandFlags.None)
{
throw new NotImplementedException();
}
public Task<long> SetAddAsync(RedisKey key, RedisValue[] values, CommandFlags flags = CommandFlags.None) public Task<long> SetAddAsync(RedisKey key, RedisValue[] values, CommandFlags flags = CommandFlags.None)
{ {
return this.Inner.SetAddAsync(this.ToInner(key), values, flags); return this.Inner.SetAddAsync(this.ToInner(key), values, flags);
......
using System;
using System.Collections;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
namespace StackExchange.Redis
{
/// <summary>
/// Represents a Lua script that can be executed on Redis.
///
/// Unlike normal Redis Lua scripts, LuaScript can have named parameters (prefixed by a @).
/// Public fields and properties of the passed in object are treated as parameters.
///
/// Parameters of type RedisKey are sent to Redis as KEY (http://redis.io/commands/eval) in addition to arguments,
/// so as to play nicely with Redis Cluster.
///
/// All members of this class are thread safe.
/// </summary>
public sealed class LuaScript
{
// Since the mapping of "script text" -> LuaScript doesn't depend on any particular details of
// the redis connection itself, this cache is global.
static readonly ConcurrentDictionary<string, WeakReference> Cache = new ConcurrentDictionary<string, WeakReference>();
/// <summary>
/// The original Lua script that was used to create this.
/// </summary>
public string OriginalScript { get; private set; }
/// <summary>
/// The Lua script that will actually be sent to Redis for execution.
///
/// All @-prefixed parameter names have been replaced at this point.
/// </summary>
public string ExecutableScript { get; private set; }
// Arguments are in the order they have to passed to the script in
internal string[] Arguments { get; private set; }
bool HasArguments { get { return Arguments != null && Arguments.Length > 0; } }
Hashtable ParameterMappers;
internal LuaScript(string originalScript, string executableScript, string[] arguments)
{
OriginalScript = originalScript;
ExecutableScript = executableScript;
Arguments = arguments;
if (HasArguments)
{
ParameterMappers = new Hashtable();
}
}
/// <summary>
/// Finalizer, used to prompt cleanups of the script cache when
/// a LuaScript reference goes out of scope.
/// </summary>
~LuaScript()
{
try
{
WeakReference ignored;
Cache.TryRemove(OriginalScript, out ignored);
}
catch { }
}
/// <summary>
/// Invalidates the internal cache of LuaScript objects.
/// Existing LuaScripts will continue to work, but future calls to LuaScript.Prepare
/// return a new LuaScript instance.
/// </summary>
public static void PurgeCache()
{
Cache.Clear();
}
/// <summary>
/// Returns the number of cached LuaScripts.
/// </summary>
public static int GetCachedScriptCount()
{
return Cache.Count;
}
/// <summary>
/// Prepares a Lua script with named parameters to be run against any Redis instance.
/// </summary>
public static LuaScript Prepare(string script)
{
LuaScript ret;
WeakReference weakRef;
if (!Cache.TryGetValue(script, out weakRef) || (ret = (LuaScript)weakRef.Target) == null)
{
ret = ScriptParameterMapper.PrepareScript(script);
Cache[script] = new WeakReference(ret);
}
return ret;
}
internal void ExtractParameters(object ps, RedisKey? keyPrefix, out RedisKey[] keys, out RedisValue[] args)
{
if (HasArguments)
{
if (ps == null) throw new ArgumentNullException("ps", "Script requires parameters");
var psType = ps.GetType();
var mapper = (Func<object, RedisKey?, ScriptParameterMapper.ScriptParameters>)ParameterMappers[psType];
if (ps != null && mapper == null)
{
lock (ParameterMappers)
{
mapper = (Func<object, RedisKey?, ScriptParameterMapper.ScriptParameters>)ParameterMappers[psType];
if (mapper == null)
{
string missingMember;
string badMemberType;
if(!ScriptParameterMapper.IsValidParameterHash(psType, this, out missingMember, out badMemberType))
{
if (missingMember != null)
{
throw new ArgumentException("ps", "Expected [" + missingMember + "] to be a field or gettable property on [" + psType.FullName + "]");
}
throw new ArgumentException("ps", "Expected [" + badMemberType + "] on [" + psType.FullName + "] to be convertable to a RedisValue");
}
ParameterMappers[psType] = mapper = ScriptParameterMapper.GetParameterExtractor(psType, this);
}
}
}
var mapped = mapper(ps, keyPrefix);
keys = mapped.Keys;
args = mapped.Arguments;
}
else
{
keys = null;
args = null;
}
}
/// <summary>
/// Evaluates this LuaScript against the given database, extracting parameters from the passed in object if any.
/// </summary>
public RedisResult Evaluate(IDatabase db, object ps = null, RedisKey? withKeyPrefix = null, CommandFlags flags = CommandFlags.None)
{
RedisKey[] keys;
RedisValue[] args;
ExtractParameters(ps, withKeyPrefix, out keys, out args);
return db.ScriptEvaluate(ExecutableScript, keys, args, flags);
}
/// <summary>
/// Evaluates this LuaScript against the given database, extracting parameters from the passed in object if any.
/// </summary>
public Task<RedisResult> EvaluateAsync(IDatabaseAsync db, object ps = null, RedisKey? withKeyPrefix = null, CommandFlags flags = CommandFlags.None)
{
RedisKey[] keys;
RedisValue[] args;
ExtractParameters(ps, withKeyPrefix, out keys, out args);
return db.ScriptEvaluateAsync(ExecutableScript, keys, args, flags);
}
/// <summary>
/// Loads this LuaScript into the given IServer so it can be run with it's SHA1 hash, instead of
/// passing the full script on each Evaluate or EvaluateAsync call.
///
/// Note: the FireAndForget command flag cannot be set
/// </summary>
public LoadedLuaScript Load(IServer server, CommandFlags flags = CommandFlags.None)
{
if (flags.HasFlag(CommandFlags.FireAndForget))
{
throw new ArgumentOutOfRangeException("flags", "Loading a script cannot be FireAndForget");
}
var hash = server.ScriptLoad(ExecutableScript, flags);
return new LoadedLuaScript(this, hash);
}
/// <summary>
/// Loads this LuaScript into the given IServer so it can be run with it's SHA1 hash, instead of
/// passing the full script on each Evaluate or EvaluateAsync call.
///
/// Note: the FireAndForget command flag cannot be set
/// </summary>
public async Task<LoadedLuaScript> LoadAsync(IServer server, CommandFlags flags = CommandFlags.None)
{
if (flags.HasFlag(CommandFlags.FireAndForget))
{
throw new ArgumentOutOfRangeException("flags", "Loading a script cannot be FireAndForget");
}
var hash = await server.ScriptLoadAsync(ExecutableScript, flags);
return new LoadedLuaScript(this, hash);
}
}
/// <summary>
/// Represents a Lua script that can be executed on Redis.
///
/// Unlike LuaScript, LoadedLuaScript sends the hash of it's ExecutableScript to Redis rather than pass
/// the whole script on each call. This requires that the script be loaded into Redis before it is used.
///
/// To create a LoadedLuaScript first create a LuaScript via LuaScript.Prepare(string), then
/// call Load(IServer, CommandFlags) on the returned LuaScript.
///
/// Unlike normal Redis Lua scripts, LoadedLuaScript can have named parameters (prefixed by a @).
/// Public fields and properties of the passed in object are treated as parameters.
///
/// Parameters of type RedisKey are sent to Redis as KEY (http://redis.io/commands/eval) in addition to arguments,
/// so as to play nicely with Redis Cluster.
///
/// All members of this class are thread safe.
/// </summary>
public sealed class LoadedLuaScript
{
/// <summary>
/// The original script that was used to create this LoadedLuaScript.
/// </summary>
public string OriginalScript { get { return Original.OriginalScript; } }
/// <summary>
/// The script that will actually be sent to Redis for execution.
/// </summary>
public string ExecutableScript { get { return Original.ExecutableScript; } }
/// <summary>
/// The SHA1 hash of ExecutableScript.
///
/// This is sent to Redis instead of ExecutableScript during Evaluate and EvaluateAsync calls.
/// </summary>
public byte[] Hash { get; private set; }
// internal for testing purposes only
internal LuaScript Original;
internal LoadedLuaScript(LuaScript original, byte[] hash)
{
Original = original;
Hash = hash;
}
/// <summary>
/// Evaluates this LoadedLuaScript against the given database, extracting parameters for the passed in object if any.
///
/// This method sends the SHA1 hash of the ExecutableScript instead of the script itself. If the script has not
/// been loaded into the passed Redis instance it will fail.
/// </summary>
public RedisResult Evaluate(IDatabase db, object ps = null, RedisKey? withKeyPrefix = null, CommandFlags flags = CommandFlags.None)
{
RedisKey[] keys;
RedisValue[] args;
Original.ExtractParameters(ps, withKeyPrefix, out keys, out args);
return db.ScriptEvaluate(Hash, keys, args, flags);
}
/// <summary>
/// Evaluates this LoadedLuaScript against the given database, extracting parameters for the passed in object if any.
///
/// This method sends the SHA1 hash of the ExecutableScript instead of the script itself. If the script has not
/// been loaded into the passed Redis instance it will fail.
/// </summary>
public Task<RedisResult> EvaluateAsync(IDatabaseAsync db, object ps = null, RedisKey? withKeyPrefix = null, CommandFlags flags = CommandFlags.None)
{
RedisKey[] keys;
RedisValue[] args;
Original.ExtractParameters(ps, withKeyPrefix, out keys, out args);
return db.ScriptEvaluateAsync(Hash, keys, args, flags);
}
}
}
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.IO; using System.IO;
using System.Linq; using System.Linq;
using System.Net; using System.Net;
using System.Net.Security; using System.Net.Security;
using System.Net.Sockets; using System.Net.Sockets;
using System.Runtime.CompilerServices; using System.Runtime.CompilerServices;
using System.Security.Authentication; using System.Security.Authentication;
using System.Security.Cryptography.X509Certificates; using System.Security.Cryptography.X509Certificates;
using System.Text; using System.Text;
using System.Threading; using System.Threading;
namespace StackExchange.Redis namespace StackExchange.Redis
{ {
internal sealed partial class PhysicalConnection : IDisposable, ISocketCallback internal sealed partial class PhysicalConnection : IDisposable, ISocketCallback
{ {
internal readonly byte[] ChannelPrefix; internal readonly byte[] ChannelPrefix;
private const int DefaultRedisDatabaseCount = 16; private const int DefaultRedisDatabaseCount = 16;
private static readonly byte[] Crlf = Encoding.ASCII.GetBytes("\r\n"); private static readonly byte[] Crlf = Encoding.ASCII.GetBytes("\r\n");
static readonly AsyncCallback endRead = result => static readonly AsyncCallback endRead = result =>
{ {
PhysicalConnection physical; PhysicalConnection physical;
if (result.CompletedSynchronously || (physical = result.AsyncState as PhysicalConnection) == null) return; if (result.CompletedSynchronously || (physical = result.AsyncState as PhysicalConnection) == null) return;
try try
{ {
physical.multiplexer.Trace("Completed synchronously: processing in callback", physical.physicalName); physical.multiplexer.Trace("Completed synchronously: processing in callback", physical.physicalName);
if (physical.EndReading(result)) physical.BeginReading(); if (physical.EndReading(result)) physical.BeginReading();
} }
catch (Exception ex) catch (Exception ex)
{ {
physical.RecordConnectionFailed(ConnectionFailureType.InternalFailure, ex); physical.RecordConnectionFailed(ConnectionFailureType.InternalFailure, ex);
} }
}; };
private static readonly byte[] message = Encoding.UTF8.GetBytes("message"), pmessage = Encoding.UTF8.GetBytes("pmessage"); private static readonly byte[] message = Encoding.UTF8.GetBytes("message"), pmessage = Encoding.UTF8.GetBytes("pmessage");
static readonly Message[] ReusableChangeDatabaseCommands = Enumerable.Range(0, DefaultRedisDatabaseCount).Select( static readonly Message[] ReusableChangeDatabaseCommands = Enumerable.Range(0, DefaultRedisDatabaseCount).Select(
i => Message.Create(i, CommandFlags.FireAndForget, RedisCommand.SELECT)).ToArray(); i => Message.Create(i, CommandFlags.FireAndForget, RedisCommand.SELECT)).ToArray();
private static readonly Message private static readonly Message
ReusableReadOnlyCommand = Message.Create(-1, CommandFlags.FireAndForget, RedisCommand.READONLY), ReusableReadOnlyCommand = Message.Create(-1, CommandFlags.FireAndForget, RedisCommand.READONLY),
ReusableReadWriteCommand = Message.Create(-1, CommandFlags.FireAndForget, RedisCommand.READWRITE); ReusableReadWriteCommand = Message.Create(-1, CommandFlags.FireAndForget, RedisCommand.READWRITE);
private static int totalCount; private static int totalCount;
private readonly PhysicalBridge bridge; private readonly PhysicalBridge bridge;
private readonly ConnectionType connectionType; private readonly ConnectionType connectionType;
private readonly ConnectionMultiplexer multiplexer; private readonly ConnectionMultiplexer multiplexer;
// things sent to this physical, but not yet received // things sent to this physical, but not yet received
private readonly Queue<Message> outstanding = new Queue<Message>(); private readonly Queue<Message> outstanding = new Queue<Message>();
readonly string physicalName; readonly string physicalName;
volatile int currentDatabase = 0; volatile int currentDatabase = 0;
ReadMode currentReadMode = ReadMode.NotSpecified; ReadMode currentReadMode = ReadMode.NotSpecified;
int failureReported; int failureReported;
byte[] ioBuffer = new byte[512]; byte[] ioBuffer = new byte[512];
int ioBufferBytes = 0; int ioBufferBytes = 0;
int lastWriteTickCount, lastReadTickCount, lastBeatTickCount; int lastWriteTickCount, lastReadTickCount, lastBeatTickCount;
int firstUnansweredWriteTickCount; int firstUnansweredWriteTickCount;
private Stream netStream, outStream; private Stream netStream, outStream;
private SocketToken socketToken; private SocketToken socketToken;
public PhysicalConnection(PhysicalBridge bridge) public PhysicalConnection(PhysicalBridge bridge)
{ {
lastWriteTickCount = lastReadTickCount = Environment.TickCount; lastWriteTickCount = lastReadTickCount = Environment.TickCount;
lastBeatTickCount = 0; lastBeatTickCount = 0;
this.connectionType = bridge.ConnectionType; this.connectionType = bridge.ConnectionType;
this.multiplexer = bridge.Multiplexer; this.multiplexer = bridge.Multiplexer;
this.ChannelPrefix = multiplexer.RawConfig.ChannelPrefix; this.ChannelPrefix = multiplexer.RawConfig.ChannelPrefix;
if (this.ChannelPrefix != null && this.ChannelPrefix.Length == 0) this.ChannelPrefix = null; // null tests are easier than null+empty if (this.ChannelPrefix != null && this.ChannelPrefix.Length == 0) this.ChannelPrefix = null; // null tests are easier than null+empty
var endpoint = bridge.ServerEndPoint.EndPoint; var endpoint = bridge.ServerEndPoint.EndPoint;
physicalName = connectionType + "#" + Interlocked.Increment(ref totalCount) + "@" + Format.ToString(endpoint); physicalName = connectionType + "#" + Interlocked.Increment(ref totalCount) + "@" + Format.ToString(endpoint);
this.bridge = bridge; this.bridge = bridge;
OnCreateEcho(); OnCreateEcho();
} }
public void BeginConnect() public void BeginConnect()
{ {
Thread.VolatileWrite(ref firstUnansweredWriteTickCount, 0); Thread.VolatileWrite(ref firstUnansweredWriteTickCount, 0);
var endpoint = this.bridge.ServerEndPoint.EndPoint; var endpoint = this.bridge.ServerEndPoint.EndPoint;
multiplexer.Trace("Connecting...", physicalName); multiplexer.Trace("Connecting...", physicalName);
this.socketToken = multiplexer.SocketManager.BeginConnect(endpoint, this); this.socketToken = multiplexer.SocketManager.BeginConnect(endpoint, this);
} }
private enum ReadMode : byte private enum ReadMode : byte
{ {
NotSpecified, NotSpecified,
ReadOnly, ReadOnly,
ReadWrite ReadWrite
} }
public PhysicalBridge Bridge { get { return bridge; } } public PhysicalBridge Bridge { get { return bridge; } }
public long LastWriteSecondsAgo public long LastWriteSecondsAgo
{ {
get get
{ {
return unchecked(Environment.TickCount - Thread.VolatileRead(ref lastWriteTickCount)) / 1000; return unchecked(Environment.TickCount - Thread.VolatileRead(ref lastWriteTickCount)) / 1000;
} }
} }
public ConnectionMultiplexer Multiplexer { get { return multiplexer; } } public ConnectionMultiplexer Multiplexer { get { return multiplexer; } }
public long SubscriptionCount { get; set; } public long SubscriptionCount { get; set; }
public bool TransactionActive { get; internal set; } public bool TransactionActive { get; internal set; }
public void Dispose() public void Dispose()
{ {
if (outStream != null) if (outStream != null)
{ {
multiplexer.Trace("Disconnecting...", physicalName); multiplexer.Trace("Disconnecting...", physicalName);
try { outStream.Close(); } catch { } try { outStream.Close(); } catch { }
try { outStream.Dispose(); } catch { } try { outStream.Dispose(); } catch { }
outStream = null; outStream = null;
} }
if (netStream != null) if (netStream != null)
{ {
try { netStream.Close(); } catch { } try { netStream.Close(); } catch { }
try { netStream.Dispose(); } catch { } try { netStream.Dispose(); } catch { }
netStream = null; netStream = null;
} }
if (socketToken.HasValue) if (socketToken.HasValue)
{ {
var socketManager = multiplexer.SocketManager; var socketManager = multiplexer.SocketManager;
if (socketManager != null) socketManager.Shutdown(socketToken); if (socketManager != null) socketManager.Shutdown(socketToken);
socketToken = default(SocketToken); socketToken = default(SocketToken);
multiplexer.Trace("Disconnected", physicalName); multiplexer.Trace("Disconnected", physicalName);
RecordConnectionFailed(ConnectionFailureType.ConnectionDisposed); RecordConnectionFailed(ConnectionFailureType.ConnectionDisposed);
} }
OnCloseEcho(); OnCloseEcho();
} }
public void Flush() public void Flush()
{ {
var tmp = outStream; var tmp = outStream;
if (tmp != null) if (tmp != null)
{ {
tmp.Flush(); tmp.Flush();
Interlocked.Exchange(ref lastWriteTickCount, Environment.TickCount); Interlocked.Exchange(ref lastWriteTickCount, Environment.TickCount);
} }
} }
public void RecordConnectionFailed(ConnectionFailureType failureType, Exception innerException = null, [CallerMemberName] string origin = null) public void RecordConnectionFailed(ConnectionFailureType failureType, Exception innerException = null, [CallerMemberName] string origin = null)
{ {
IdentifyFailureType(innerException, ref failureType); IdentifyFailureType(innerException, ref failureType);
if (failureType == ConnectionFailureType.InternalFailure) OnInternalError(innerException, origin); if (failureType == ConnectionFailureType.InternalFailure) OnInternalError(innerException, origin);
// stop anything new coming in... // stop anything new coming in...
bridge.Trace("Failed: " + failureType); bridge.Trace("Failed: " + failureType);
bool isCurrent; bool isCurrent;
PhysicalBridge.State oldState; PhysicalBridge.State oldState;
bridge.OnDisconnected(failureType, this, out isCurrent, out oldState); bridge.OnDisconnected(failureType, this, out isCurrent, out oldState);
if (isCurrent && Interlocked.CompareExchange(ref failureReported, 1, 0) == 0) if (isCurrent && Interlocked.CompareExchange(ref failureReported, 1, 0) == 0)
{ {
int now = Environment.TickCount, lastRead = Thread.VolatileRead(ref lastReadTickCount), lastWrite = Thread.VolatileRead(ref lastWriteTickCount), int now = Environment.TickCount, lastRead = Thread.VolatileRead(ref lastReadTickCount), lastWrite = Thread.VolatileRead(ref lastWriteTickCount),
lastBeat = Thread.VolatileRead(ref lastBeatTickCount); lastBeat = Thread.VolatileRead(ref lastBeatTickCount);
int unansweredRead = Thread.VolatileRead(ref firstUnansweredWriteTickCount); int unansweredRead = Thread.VolatileRead(ref firstUnansweredWriteTickCount);
string message = failureType + " on " + Format.ToString(bridge.ServerEndPoint.EndPoint) + "/" + connectionType string message = failureType + " on " + Format.ToString(bridge.ServerEndPoint.EndPoint) + "/" + connectionType
+ ", input-buffer: " + ioBufferBytes + ", outstanding: " + GetSentAwaitingResponseCount() + ", input-buffer: " + ioBufferBytes + ", outstanding: " + GetSentAwaitingResponseCount()
+ ", last-read: " + unchecked(now - lastRead) / 1000 + "s ago, last-write: " + unchecked(now - lastWrite) / 1000 + "s ago" + ", last-read: " + unchecked(now - lastRead) / 1000 + "s ago, last-write: " + unchecked(now - lastWrite) / 1000 + "s ago"
+ ", unanswered-write: " + unchecked(now - unansweredRead) / 1000 + "s ago" + ", unanswered-write: " + unchecked(now - unansweredRead) / 1000 + "s ago"
+ ", keep-alive: " + bridge.ServerEndPoint.WriteEverySeconds + "s, pending: " + ", keep-alive: " + bridge.ServerEndPoint.WriteEverySeconds + "s, pending: "
+ bridge.GetPendingCount() + ", state: " + oldState + ", last-heartbeat: " + (lastBeat == 0 ? "never" : (unchecked(now - lastBeat) / 1000 + "s ago")) + bridge.GetPendingCount() + ", state: " + oldState + ", last-heartbeat: " + (lastBeat == 0 ? "never" : (unchecked(now - lastBeat) / 1000 + "s ago"))
+ (bridge.IsBeating ? " (mid-beat)" : "") + ", last-mbeat: " + multiplexer.LastHeartbeatSecondsAgo + "s ago, global: " + (bridge.IsBeating ? " (mid-beat)" : "") + ", last-mbeat: " + multiplexer.LastHeartbeatSecondsAgo + "s ago, global: "
+ ConnectionMultiplexer.LastGlobalHeartbeatSecondsAgo + "s ago"; + ConnectionMultiplexer.LastGlobalHeartbeatSecondsAgo + "s ago";
var ex = innerException == null var ex = innerException == null
? new RedisConnectionException(failureType, message) ? new RedisConnectionException(failureType, message)
: new RedisConnectionException(failureType, message, innerException); : new RedisConnectionException(failureType, message, innerException);
bridge.OnConnectionFailed(this, failureType, ex); bridge.OnConnectionFailed(this, failureType, ex);
} }
// cleanup // cleanup
lock (outstanding) lock (outstanding)
{ {
bridge.Trace(outstanding.Count != 0, "Failing outstanding messages: " + outstanding.Count); bridge.Trace(outstanding.Count != 0, "Failing outstanding messages: " + outstanding.Count);
while (outstanding.Count != 0) while (outstanding.Count != 0)
{ {
var next = outstanding.Dequeue(); var next = outstanding.Dequeue();
bridge.Trace("Failing: " + next); bridge.Trace("Failing: " + next);
next.Fail(failureType, innerException); next.Fail(failureType, innerException);
bridge.CompleteSyncOrAsync(next); bridge.CompleteSyncOrAsync(next);
} }
} }
// burn the socket // burn the socket
var socketManager = multiplexer.SocketManager; var socketManager = multiplexer.SocketManager;
if (socketManager != null) socketManager.Shutdown(socketToken); if (socketManager != null) socketManager.Shutdown(socketToken);
} }
public override string ToString() public override string ToString()
{ {
return physicalName; return physicalName;
} }
internal static void IdentifyFailureType(Exception exception, ref ConnectionFailureType failureType) internal static void IdentifyFailureType(Exception exception, ref ConnectionFailureType failureType)
{ {
if (exception != null && failureType == ConnectionFailureType.InternalFailure) if (exception != null && failureType == ConnectionFailureType.InternalFailure)
{ {
if (exception is AuthenticationException) failureType = ConnectionFailureType.AuthenticationFailure; if (exception is AuthenticationException) failureType = ConnectionFailureType.AuthenticationFailure;
else if (exception is SocketException || exception is IOException) failureType = ConnectionFailureType.SocketFailure; else if (exception is SocketException || exception is IOException) failureType = ConnectionFailureType.SocketFailure;
else if (exception is EndOfStreamException) failureType = ConnectionFailureType.SocketClosed; else if (exception is EndOfStreamException) failureType = ConnectionFailureType.SocketClosed;
else if (exception is ObjectDisposedException) failureType = ConnectionFailureType.SocketClosed; else if (exception is ObjectDisposedException) failureType = ConnectionFailureType.SocketClosed;
} }
} }
internal void Enqueue(Message next) internal void Enqueue(Message next)
{ {
lock (outstanding) lock (outstanding)
{ {
outstanding.Enqueue(next); outstanding.Enqueue(next);
} }
} }
internal void GetCounters(ConnectionCounters counters) internal void GetCounters(ConnectionCounters counters)
{ {
lock (outstanding) lock (outstanding)
{ {
counters.SentItemsAwaitingResponse = outstanding.Count; counters.SentItemsAwaitingResponse = outstanding.Count;
} }
counters.Subscriptions = SubscriptionCount; counters.Subscriptions = SubscriptionCount;
} }
internal Message GetReadModeCommand(bool isMasterOnly) internal Message GetReadModeCommand(bool isMasterOnly)
{ {
var serverEndpoint = bridge.ServerEndPoint; var serverEndpoint = bridge.ServerEndPoint;
if (serverEndpoint.RequiresReadMode) if (serverEndpoint.RequiresReadMode)
{ {
ReadMode requiredReadMode = isMasterOnly ? ReadMode.ReadWrite : ReadMode.ReadOnly; ReadMode requiredReadMode = isMasterOnly ? ReadMode.ReadWrite : ReadMode.ReadOnly;
if (requiredReadMode != currentReadMode) if (requiredReadMode != currentReadMode)
{ {
currentReadMode = requiredReadMode; currentReadMode = requiredReadMode;
switch (requiredReadMode) switch (requiredReadMode)
{ {
case ReadMode.ReadOnly: return ReusableReadOnlyCommand; case ReadMode.ReadOnly: return ReusableReadOnlyCommand;
case ReadMode.ReadWrite: return ReusableReadWriteCommand; case ReadMode.ReadWrite: return ReusableReadWriteCommand;
} }
} }
} }
else if (currentReadMode == ReadMode.ReadOnly) else if (currentReadMode == ReadMode.ReadOnly)
{ // we don't need it (because we're not a cluster, or not a slave), { // we don't need it (because we're not a cluster, or not a slave),
// but we are in read-only mode; switch to read-write // but we are in read-only mode; switch to read-write
currentReadMode = ReadMode.ReadWrite; currentReadMode = ReadMode.ReadWrite;
return ReusableReadWriteCommand; return ReusableReadWriteCommand;
} }
return null; return null;
} }
internal Message GetSelectDatabaseCommand(int targetDatabase, Message message) internal Message GetSelectDatabaseCommand(int targetDatabase, Message message)
{ {
if (targetDatabase < 0) return null; if (targetDatabase < 0) return null;
if (targetDatabase != currentDatabase) if (targetDatabase != currentDatabase)
{ {
var serverEndpoint = bridge.ServerEndPoint; var serverEndpoint = bridge.ServerEndPoint;
int available = serverEndpoint.Databases; int available = serverEndpoint.Databases;
if (!serverEndpoint.HasDatabases) // only db0 is available on cluster/twemproxy if (!serverEndpoint.HasDatabases) // only db0 is available on cluster/twemproxy
{ {
if (targetDatabase != 0) if (targetDatabase != 0)
{ // should never see this, since the API doesn't allow it; thus not too worried about ExceptionFactory { // should never see this, since the API doesn't allow it; thus not too worried about ExceptionFactory
throw new RedisCommandException("Multiple databases are not supported on this server; cannot switch to database: " + targetDatabase); throw new RedisCommandException("Multiple databases are not supported on this server; cannot switch to database: " + targetDatabase);
} }
return null; return null;
} }
if(message.Command == RedisCommand.SELECT) if(message.Command == RedisCommand.SELECT)
{ {
// this could come from an EVAL/EVALSHA inside a transaction, for example; we'll accept it // this could come from an EVAL/EVALSHA inside a transaction, for example; we'll accept it
bridge.Trace("Switching database: " + targetDatabase); bridge.Trace("Switching database: " + targetDatabase);
currentDatabase = targetDatabase; currentDatabase = targetDatabase;
return null; return null;
} }
if (TransactionActive) if (TransactionActive)
{// should never see this, since the API doesn't allow it; thus not too worried about ExceptionFactory {// should never see this, since the API doesn't allow it; thus not too worried about ExceptionFactory
throw new RedisCommandException("Multiple databases inside a transaction are not currently supported: " + targetDatabase); throw new RedisCommandException("Multiple databases inside a transaction are not currently supported: " + targetDatabase);
} }
if (available != 0 && targetDatabase >= available) // we positively know it is out of range if (available != 0 && targetDatabase >= available) // we positively know it is out of range
{ {
throw ExceptionFactory.DatabaseOutfRange(multiplexer.IncludeDetailInExceptions, targetDatabase, message, serverEndpoint); throw ExceptionFactory.DatabaseOutfRange(multiplexer.IncludeDetailInExceptions, targetDatabase, message, serverEndpoint);
} }
bridge.Trace("Switching database: " + targetDatabase); bridge.Trace("Switching database: " + targetDatabase);
currentDatabase = targetDatabase; currentDatabase = targetDatabase;
return GetSelectDatabaseCommand(targetDatabase); return GetSelectDatabaseCommand(targetDatabase);
} }
return null; return null;
} }
internal static Message GetSelectDatabaseCommand(int targetDatabase) internal static Message GetSelectDatabaseCommand(int targetDatabase)
{ {
return targetDatabase < DefaultRedisDatabaseCount return targetDatabase < DefaultRedisDatabaseCount
? ReusableChangeDatabaseCommands[targetDatabase] // 0-15 by default ? ReusableChangeDatabaseCommands[targetDatabase] // 0-15 by default
: Message.Create(targetDatabase, CommandFlags.FireAndForget, RedisCommand.SELECT); : Message.Create(targetDatabase, CommandFlags.FireAndForget, RedisCommand.SELECT);
} }
internal int GetSentAwaitingResponseCount() internal int GetSentAwaitingResponseCount()
{ {
lock (outstanding) lock (outstanding)
{ {
return outstanding.Count; return outstanding.Count;
} }
} }
internal void GetStormLog(StringBuilder sb) internal void GetStormLog(StringBuilder sb)
{ {
lock (outstanding) lock (outstanding)
{ {
if (outstanding.Count == 0) return; if (outstanding.Count == 0) return;
sb.Append("Sent, awaiting response from server: ").Append(outstanding.Count).AppendLine(); sb.Append("Sent, awaiting response from server: ").Append(outstanding.Count).AppendLine();
int total = 0; int total = 0;
foreach (var item in outstanding) foreach (var item in outstanding)
{ {
if (++total >= 500) break; if (++total >= 500) break;
item.AppendStormLog(sb); item.AppendStormLog(sb);
sb.AppendLine(); sb.AppendLine();
} }
} }
} }
internal void OnHeartbeat() internal void OnHeartbeat()
{ {
Interlocked.Exchange(ref lastBeatTickCount, Environment.TickCount); Interlocked.Exchange(ref lastBeatTickCount, Environment.TickCount);
} }
internal void OnInternalError(Exception exception, [CallerMemberName] string origin = null) internal void OnInternalError(Exception exception, [CallerMemberName] string origin = null)
{ {
multiplexer.OnInternalError(exception, bridge.ServerEndPoint.EndPoint, connectionType, origin); multiplexer.OnInternalError(exception, bridge.ServerEndPoint.EndPoint, connectionType, origin);
} }
internal void SetUnknownDatabase() internal void SetUnknownDatabase()
{ // forces next db-specific command to issue a select { // forces next db-specific command to issue a select
currentDatabase = -1; currentDatabase = -1;
} }
internal void Write(RedisKey key) internal void Write(RedisKey key)
{ {
var val = key.KeyValue; var val = key.KeyValue;
if (val is string) if (val is string)
{ {
WriteUnified(outStream, key.KeyPrefix, (string)val); WriteUnified(outStream, key.KeyPrefix, (string)val);
} }
else else
{ {
WriteUnified(outStream, key.KeyPrefix, (byte[])val); WriteUnified(outStream, key.KeyPrefix, (byte[])val);
} }
} }
internal void Write(RedisChannel channel) internal void Write(RedisChannel channel)
{ {
WriteUnified(outStream, ChannelPrefix, channel.Value); WriteUnified(outStream, ChannelPrefix, channel.Value);
} }
internal void Write(RedisValue value) internal void Write(RedisValue value)
{ {
if (value.IsInteger) if (value.IsInteger)
{ {
WriteUnified(outStream, (long)value); WriteUnified(outStream, (long)value);
} }
else else
{ {
WriteUnified(outStream, (byte[])value); WriteUnified(outStream, (byte[])value);
} }
} }
internal void WriteHeader(RedisCommand command, int arguments) internal void WriteHeader(RedisCommand command, int arguments)
{ {
var commandBytes = multiplexer.CommandMap.GetBytes(command); var commandBytes = multiplexer.CommandMap.GetBytes(command);
if (commandBytes == null) if (commandBytes == null)
{ {
throw ExceptionFactory.CommandDisabled(multiplexer.IncludeDetailInExceptions, command, null, bridge.ServerEndPoint); throw ExceptionFactory.CommandDisabled(multiplexer.IncludeDetailInExceptions, command, null, bridge.ServerEndPoint);
} }
outStream.WriteByte((byte)'*'); outStream.WriteByte((byte)'*');
// remember the time of the first write that still not followed by read // remember the time of the first write that still not followed by read
Interlocked.CompareExchange(ref firstUnansweredWriteTickCount, Environment.TickCount, 0); Interlocked.CompareExchange(ref firstUnansweredWriteTickCount, Environment.TickCount, 0);
WriteRaw(outStream, arguments + 1); WriteRaw(outStream, arguments + 1);
WriteUnified(outStream, commandBytes); WriteUnified(outStream, commandBytes);
} }
static void WriteRaw(Stream stream, long value, bool withLengthPrefix = false) static void WriteRaw(Stream stream, long value, bool withLengthPrefix = false)
{ {
if (value >= 0 && value <= 9) if (value >= 0 && value <= 9)
{ {
if (withLengthPrefix) if (withLengthPrefix)
{ {
stream.WriteByte((byte)'1'); stream.WriteByte((byte)'1');
stream.Write(Crlf, 0, 2); stream.Write(Crlf, 0, 2);
} }
stream.WriteByte((byte)((int)'0' + (int)value)); stream.WriteByte((byte)((int)'0' + (int)value));
} }
else if (value >= 10 && value < 100) else if (value >= 10 && value < 100)
{ {
if (withLengthPrefix) if (withLengthPrefix)
{ {
stream.WriteByte((byte)'2'); stream.WriteByte((byte)'2');
stream.Write(Crlf, 0, 2); stream.Write(Crlf, 0, 2);
} }
stream.WriteByte((byte)((int)'0' + (int)value / 10)); stream.WriteByte((byte)((int)'0' + (int)value / 10));
stream.WriteByte((byte)((int)'0' + (int)value % 10)); stream.WriteByte((byte)((int)'0' + (int)value % 10));
} }
else if (value >= 100 && value < 1000) else if (value >= 100 && value < 1000)
{ {
int v = (int)value; int v = (int)value;
int units = v % 10; int units = v % 10;
v /= 10; v /= 10;
int tens = v % 10, hundreds = v / 10; int tens = v % 10, hundreds = v / 10;
if (withLengthPrefix) if (withLengthPrefix)
{ {
stream.WriteByte((byte)'3'); stream.WriteByte((byte)'3');
stream.Write(Crlf, 0, 2); stream.Write(Crlf, 0, 2);
} }
stream.WriteByte((byte)((int)'0' + hundreds)); stream.WriteByte((byte)((int)'0' + hundreds));
stream.WriteByte((byte)((int)'0' + tens)); stream.WriteByte((byte)((int)'0' + tens));
stream.WriteByte((byte)((int)'0' + units)); stream.WriteByte((byte)((int)'0' + units));
} }
else if (value < 0 && value >= -9) else if (value < 0 && value >= -9)
{ {
if (withLengthPrefix) if (withLengthPrefix)
{ {
stream.WriteByte((byte)'2'); stream.WriteByte((byte)'2');
stream.Write(Crlf, 0, 2); stream.Write(Crlf, 0, 2);
} }
stream.WriteByte((byte)'-'); stream.WriteByte((byte)'-');
stream.WriteByte((byte)((int)'0' - (int)value)); stream.WriteByte((byte)((int)'0' - (int)value));
} }
else if (value <= -10 && value > -100) else if (value <= -10 && value > -100)
{ {
if (withLengthPrefix) if (withLengthPrefix)
{ {
stream.WriteByte((byte)'3'); stream.WriteByte((byte)'3');
stream.Write(Crlf, 0, 2); stream.Write(Crlf, 0, 2);
} }
value = -value; value = -value;
stream.WriteByte((byte)'-'); stream.WriteByte((byte)'-');
stream.WriteByte((byte)((int)'0' + (int)value / 10)); stream.WriteByte((byte)((int)'0' + (int)value / 10));
stream.WriteByte((byte)((int)'0' + (int)value % 10)); stream.WriteByte((byte)((int)'0' + (int)value % 10));
} }
else else
{ {
var bytes = Encoding.ASCII.GetBytes(Format.ToString(value)); var bytes = Encoding.ASCII.GetBytes(Format.ToString(value));
if (withLengthPrefix) if (withLengthPrefix)
{ {
WriteRaw(stream, bytes.Length, false); WriteRaw(stream, bytes.Length, false);
} }
stream.Write(bytes, 0, bytes.Length); stream.Write(bytes, 0, bytes.Length);
} }
stream.Write(Crlf, 0, 2); stream.Write(Crlf, 0, 2);
} }
static void WriteUnified(Stream stream, byte[] value) static void WriteUnified(Stream stream, byte[] value)
{ {
stream.WriteByte((byte)'$'); stream.WriteByte((byte)'$');
if (value == null) if (value == null)
{ {
WriteRaw(stream, -1); // note that not many things like this... WriteRaw(stream, -1); // note that not many things like this...
} }
else else
{ {
WriteRaw(stream, value.Length); WriteRaw(stream, value.Length);
stream.Write(value, 0, value.Length); stream.Write(value, 0, value.Length);
stream.Write(Crlf, 0, 2); stream.Write(Crlf, 0, 2);
} }
} }
internal void WriteAsHex(byte[] value) internal void WriteAsHex(byte[] value)
{ {
var stream = outStream; var stream = outStream;
stream.WriteByte((byte)'$'); stream.WriteByte((byte)'$');
if (value == null) if (value == null)
{ {
WriteRaw(stream, -1); WriteRaw(stream, -1);
} else } else
{ {
WriteRaw(stream, value.Length * 2); WriteRaw(stream, value.Length * 2);
for(int i = 0; i < value.Length; i++) for(int i = 0; i < value.Length; i++)
{ {
stream.WriteByte(ToHexNibble(value[i] >> 4)); stream.WriteByte(ToHexNibble(value[i] >> 4));
stream.WriteByte(ToHexNibble(value[i] & 15)); stream.WriteByte(ToHexNibble(value[i] & 15));
} }
stream.Write(Crlf, 0, 2); stream.Write(Crlf, 0, 2);
} }
} }
internal static byte ToHexNibble(int value) internal static byte ToHexNibble(int value)
{ {
return value < 10 ? (byte)('0' + value) : (byte)('a' - 10 + value); return value < 10 ? (byte)('0' + value) : (byte)('a' - 10 + value);
} }
void WriteUnified(Stream stream, byte[] prefix, string value) void WriteUnified(Stream stream, byte[] prefix, string value)
{ {
stream.WriteByte((byte)'$'); stream.WriteByte((byte)'$');
if (value == null) if (value == null)
{ {
WriteRaw(stream, -1); // note that not many things like this... WriteRaw(stream, -1); // note that not many things like this...
} }
else else
{ {
int encodedLength = Encoding.UTF8.GetByteCount(value); int encodedLength = Encoding.UTF8.GetByteCount(value);
if (prefix == null) if (prefix == null)
{ {
WriteRaw(stream, encodedLength); WriteRaw(stream, encodedLength);
WriteRaw(stream, value, encodedLength); WriteRaw(stream, value, encodedLength);
stream.Write(Crlf, 0, 2); stream.Write(Crlf, 0, 2);
} }
else else
{ {
WriteRaw(stream, prefix.Length + encodedLength); WriteRaw(stream, prefix.Length + encodedLength);
stream.Write(prefix, 0, prefix.Length); stream.Write(prefix, 0, prefix.Length);
WriteRaw(stream, value, encodedLength); WriteRaw(stream, value, encodedLength);
stream.Write(Crlf, 0, 2); stream.Write(Crlf, 0, 2);
} }
} }
} }
unsafe void WriteRaw(Stream stream, string value, int encodedLength) unsafe void WriteRaw(Stream stream, string value, int encodedLength)
{ {
if (encodedLength <= ScratchSize) if (encodedLength <= ScratchSize)
{ {
int bytes = Encoding.UTF8.GetBytes(value, 0, value.Length, outScratch, 0); int bytes = Encoding.UTF8.GetBytes(value, 0, value.Length, outScratch, 0);
stream.Write(outScratch, 0, bytes); stream.Write(outScratch, 0, bytes);
} }
else else
{ {
fixed (char* c = value) fixed (char* c = value)
fixed (byte* b = outScratch) fixed (byte* b = outScratch)
{ {
int charsRemaining = value.Length, charOffset = 0, bytesWritten; int charsRemaining = value.Length, charOffset = 0, bytesWritten;
while (charsRemaining > Scratch_CharsPerBlock) while (charsRemaining > Scratch_CharsPerBlock)
{ {
bytesWritten = outEncoder.GetBytes(c + charOffset, Scratch_CharsPerBlock, b, ScratchSize, false); bytesWritten = outEncoder.GetBytes(c + charOffset, Scratch_CharsPerBlock, b, ScratchSize, false);
stream.Write(outScratch, 0, bytesWritten); stream.Write(outScratch, 0, bytesWritten);
charOffset += Scratch_CharsPerBlock; charOffset += Scratch_CharsPerBlock;
charsRemaining -= Scratch_CharsPerBlock; charsRemaining -= Scratch_CharsPerBlock;
} }
bytesWritten = outEncoder.GetBytes(c + charOffset, charsRemaining, b, ScratchSize, true); bytesWritten = outEncoder.GetBytes(c + charOffset, charsRemaining, b, ScratchSize, true);
if (bytesWritten != 0) stream.Write(outScratch, 0, bytesWritten); if (bytesWritten != 0) stream.Write(outScratch, 0, bytesWritten);
} }
} }
} }
const int ScratchSize = 512; const int ScratchSize = 512;
static readonly int Scratch_CharsPerBlock = ScratchSize / Encoding.UTF8.GetMaxByteCount(1); static readonly int Scratch_CharsPerBlock = ScratchSize / Encoding.UTF8.GetMaxByteCount(1);
private readonly byte[] outScratch = new byte[ScratchSize]; private readonly byte[] outScratch = new byte[ScratchSize];
private readonly Encoder outEncoder = Encoding.UTF8.GetEncoder(); private readonly Encoder outEncoder = Encoding.UTF8.GetEncoder();
static void WriteUnified(Stream stream, byte[] prefix, byte[] value) static void WriteUnified(Stream stream, byte[] prefix, byte[] value)
{ {
stream.WriteByte((byte)'$'); stream.WriteByte((byte)'$');
if (value == null) if (value == null)
{ {
WriteRaw(stream, -1); // note that not many things like this... WriteRaw(stream, -1); // note that not many things like this...
} }
else if (prefix == null) else if (prefix == null)
{ {
WriteRaw(stream, value.Length); WriteRaw(stream, value.Length);
stream.Write(value, 0, value.Length); stream.Write(value, 0, value.Length);
stream.Write(Crlf, 0, 2); stream.Write(Crlf, 0, 2);
} }
else else
{ {
WriteRaw(stream, prefix.Length + value.Length); WriteRaw(stream, prefix.Length + value.Length);
stream.Write(prefix, 0, prefix.Length); stream.Write(prefix, 0, prefix.Length);
stream.Write(value, 0, value.Length); stream.Write(value, 0, value.Length);
stream.Write(Crlf, 0, 2); stream.Write(Crlf, 0, 2);
} }
} }
static void WriteUnified(Stream stream, long value) static void WriteUnified(Stream stream, long value)
{ {
// note from specification: A client sends to the Redis server a RESP Array consisting of just Bulk Strings. // note from specification: A client sends to the Redis server a RESP Array consisting of just Bulk Strings.
// (i.e. we can't just send ":123\r\n", we need to send "$3\r\n123\r\n" // (i.e. we can't just send ":123\r\n", we need to send "$3\r\n123\r\n"
stream.WriteByte((byte)'$'); stream.WriteByte((byte)'$');
WriteRaw(stream, value, withLengthPrefix: true); WriteRaw(stream, value, withLengthPrefix: true);
} }
void BeginReading() void BeginReading()
{ {
bool keepReading; bool keepReading;
try try
{ {
do do
{ {
keepReading = false; keepReading = false;
int space = EnsureSpaceAndComputeBytesToRead(); int space = EnsureSpaceAndComputeBytesToRead();
multiplexer.Trace("Beginning async read...", physicalName); multiplexer.Trace("Beginning async read...", physicalName);
var result = netStream.BeginRead(ioBuffer, ioBufferBytes, space, endRead, this); var result = netStream.BeginRead(ioBuffer, ioBufferBytes, space, endRead, this);
if (result.CompletedSynchronously) if (result.CompletedSynchronously)
{ {
multiplexer.Trace("Completed synchronously: processing immediately", physicalName); multiplexer.Trace("Completed synchronously: processing immediately", physicalName);
keepReading = EndReading(result); keepReading = EndReading(result);
} }
} while (keepReading); } while (keepReading);
} }
catch(System.IO.IOException ex) catch(System.IO.IOException ex)
{ {
multiplexer.Trace("Could not connect: " + ex.Message, physicalName); multiplexer.Trace("Could not connect: " + ex.Message, physicalName);
} }
} }
int haveReader; int haveReader;
internal int GetAvailableInboundBytes(out int activeReaders) internal int GetAvailableInboundBytes(out int activeReaders)
{ {
activeReaders = Interlocked.CompareExchange(ref haveReader, 0, 0); activeReaders = Interlocked.CompareExchange(ref haveReader, 0, 0);
return this.socketToken.Available; return this.socketToken.Available;
} }
static LocalCertificateSelectionCallback GetAmbientCertificateCallback() static LocalCertificateSelectionCallback GetAmbientCertificateCallback()
...@@ -646,357 +646,357 @@ static LocalCertificateSelectionCallback GetAmbientCertificateCallback() ...@@ -646,357 +646,357 @@ static LocalCertificateSelectionCallback GetAmbientCertificateCallback()
} catch } catch
{ } { }
return null; return null;
} }
SocketMode ISocketCallback.Connected(Stream stream) SocketMode ISocketCallback.Connected(Stream stream)
{ {
try try
{ {
var socketMode = SocketManager.DefaultSocketMode; var socketMode = SocketManager.DefaultSocketMode;
// disallow connection in some cases // disallow connection in some cases
OnDebugAbort(); OnDebugAbort();
// the order is important here: // the order is important here:
// [network]<==[ssl]<==[logging]<==[buffered] // [network]<==[ssl]<==[logging]<==[buffered]
var config = multiplexer.RawConfig; var config = multiplexer.RawConfig;
if(config.Ssl) if(config.Ssl)
{ {
var host = config.SslHost; var host = config.SslHost;
if (string.IsNullOrWhiteSpace(host)) host = Format.ToStringHostOnly(bridge.ServerEndPoint.EndPoint); if (string.IsNullOrWhiteSpace(host)) host = Format.ToStringHostOnly(bridge.ServerEndPoint.EndPoint);
var ssl = new SslStream(stream, false, config.CertificateValidationCallback, var ssl = new SslStream(stream, false, config.CertificateValidationCallback,
config.CertificateSelectionCallback ?? GetAmbientCertificateCallback() config.CertificateSelectionCallback ?? GetAmbientCertificateCallback()
#if !__MonoCS__ #if !__MonoCS__
, EncryptionPolicy.RequireEncryption , EncryptionPolicy.RequireEncryption
#endif #endif
); );
ssl.AuthenticateAsClient(host); ssl.AuthenticateAsClient(host);
if (!ssl.IsEncrypted) if (!ssl.IsEncrypted)
{ {
RecordConnectionFailed(ConnectionFailureType.AuthenticationFailure); RecordConnectionFailed(ConnectionFailureType.AuthenticationFailure);
multiplexer.Trace("Encryption failure"); multiplexer.Trace("Encryption failure");
return SocketMode.Abort; return SocketMode.Abort;
} }
stream = ssl; stream = ssl;
socketMode = SocketMode.Async; socketMode = SocketMode.Async;
} }
OnWrapForLogging(ref stream, physicalName); OnWrapForLogging(ref stream, physicalName);
int bufferSize = config.WriteBuffer; int bufferSize = config.WriteBuffer;
this.netStream = stream; this.netStream = stream;
this.outStream = bufferSize <= 0 ? stream : new BufferedStream(stream, bufferSize); this.outStream = bufferSize <= 0 ? stream : new BufferedStream(stream, bufferSize);
multiplexer.Trace("Connected", physicalName); multiplexer.Trace("Connected", physicalName);
bridge.OnConnected(this); bridge.OnConnected(this);
return socketMode; return socketMode;
} }
catch (Exception ex) catch (Exception ex)
{ {
RecordConnectionFailed(ConnectionFailureType.InternalFailure, ex); // includes a bridge.OnDisconnected RecordConnectionFailed(ConnectionFailureType.InternalFailure, ex); // includes a bridge.OnDisconnected
multiplexer.Trace("Could not connect: " + ex.Message, physicalName); multiplexer.Trace("Could not connect: " + ex.Message, physicalName);
return SocketMode.Abort; return SocketMode.Abort;
} }
} }
private bool EndReading(IAsyncResult result) private bool EndReading(IAsyncResult result)
{ {
try try
{ {
var tmp = netStream; var tmp = netStream;
int bytesRead = tmp == null ? 0 : tmp.EndRead(result); int bytesRead = tmp == null ? 0 : tmp.EndRead(result);
return ProcessReadBytes(bytesRead); return ProcessReadBytes(bytesRead);
} }
catch (Exception ex) catch (Exception ex)
{ {
RecordConnectionFailed(ConnectionFailureType.InternalFailure, ex); RecordConnectionFailed(ConnectionFailureType.InternalFailure, ex);
return false; return false;
} }
} }
int EnsureSpaceAndComputeBytesToRead() int EnsureSpaceAndComputeBytesToRead()
{ {
int space = ioBuffer.Length - ioBufferBytes; int space = ioBuffer.Length - ioBufferBytes;
if (space == 0) if (space == 0)
{ {
Array.Resize(ref ioBuffer, ioBuffer.Length * 2); Array.Resize(ref ioBuffer, ioBuffer.Length * 2);
space = ioBuffer.Length - ioBufferBytes; space = ioBuffer.Length - ioBufferBytes;
} }
return space; return space;
} }
void ISocketCallback.Error() void ISocketCallback.Error()
{ {
RecordConnectionFailed(ConnectionFailureType.SocketFailure); RecordConnectionFailed(ConnectionFailureType.SocketFailure);
} }
void MatchResult(RawResult result) void MatchResult(RawResult result)
{ {
// check to see if it could be an out-of-band pubsub message // check to see if it could be an out-of-band pubsub message
if (connectionType == ConnectionType.Subscription && result.Type == ResultType.MultiBulk) if (connectionType == ConnectionType.Subscription && result.Type == ResultType.MultiBulk)
{ // out of band message does not match to a queued message { // out of band message does not match to a queued message
var items = result.GetItems(); var items = result.GetItems();
if (items.Length >= 3 && items[0].IsEqual(message)) if (items.Length >= 3 && items[0].IsEqual(message))
{ {
// special-case the configuration change broadcasts (we don't keep that in the usual pub/sub registry) // special-case the configuration change broadcasts (we don't keep that in the usual pub/sub registry)
var configChanged = multiplexer.ConfigurationChangedChannel; var configChanged = multiplexer.ConfigurationChangedChannel;
if (configChanged != null && items[1].IsEqual(configChanged)) if (configChanged != null && items[1].IsEqual(configChanged))
{ {
EndPoint blame = null; EndPoint blame = null;
try try
{ {
if (!items[2].IsEqual(RedisLiterals.ByteWildcard)) if (!items[2].IsEqual(RedisLiterals.ByteWildcard))
{ {
blame = Format.TryParseEndPoint(items[2].GetString()); blame = Format.TryParseEndPoint(items[2].GetString());
} }
} }
catch { /* no biggie */ } catch { /* no biggie */ }
multiplexer.Trace("Configuration changed: " + Format.ToString(blame), physicalName); multiplexer.Trace("Configuration changed: " + Format.ToString(blame), physicalName);
multiplexer.ReconfigureIfNeeded(blame, true, "broadcast"); multiplexer.ReconfigureIfNeeded(blame, true, "broadcast");
} }
// invoke the handlers // invoke the handlers
var channel = items[1].AsRedisChannel(ChannelPrefix, RedisChannel.PatternMode.Literal); var channel = items[1].AsRedisChannel(ChannelPrefix, RedisChannel.PatternMode.Literal);
multiplexer.Trace("MESSAGE: " + channel, physicalName); multiplexer.Trace("MESSAGE: " + channel, physicalName);
if (!channel.IsNull) if (!channel.IsNull)
{ {
multiplexer.OnMessage(channel, channel, items[2].AsRedisValue()); multiplexer.OnMessage(channel, channel, items[2].AsRedisValue());
} }
return; // AND STOP PROCESSING! return; // AND STOP PROCESSING!
} }
else if (items.Length >= 4 && items[0].IsEqual(pmessage)) else if (items.Length >= 4 && items[0].IsEqual(pmessage))
{ {
var channel = items[2].AsRedisChannel(ChannelPrefix, RedisChannel.PatternMode.Literal); var channel = items[2].AsRedisChannel(ChannelPrefix, RedisChannel.PatternMode.Literal);
multiplexer.Trace("PMESSAGE: " + channel, physicalName); multiplexer.Trace("PMESSAGE: " + channel, physicalName);
if (!channel.IsNull) if (!channel.IsNull)
{ {
var sub = items[1].AsRedisChannel(ChannelPrefix, RedisChannel.PatternMode.Pattern); var sub = items[1].AsRedisChannel(ChannelPrefix, RedisChannel.PatternMode.Pattern);
multiplexer.OnMessage(sub, channel, items[3].AsRedisValue()); multiplexer.OnMessage(sub, channel, items[3].AsRedisValue());
} }
return; // AND STOP PROCESSING! return; // AND STOP PROCESSING!
} }
// if it didn't look like "[p]message", then we still need to process the pending queue // if it didn't look like "[p]message", then we still need to process the pending queue
} }
multiplexer.Trace("Matching result...", physicalName); multiplexer.Trace("Matching result...", physicalName);
Message msg; Message msg;
lock (outstanding) lock (outstanding)
{ {
multiplexer.Trace(outstanding.Count == 0, "Nothing to respond to!", physicalName); multiplexer.Trace(outstanding.Count == 0, "Nothing to respond to!", physicalName);
msg = outstanding.Dequeue(); msg = outstanding.Dequeue();
} }
multiplexer.Trace("Response to: " + msg.ToString(), physicalName); multiplexer.Trace("Response to: " + msg.ToString(), physicalName);
if (msg.ComputeResult(this, result)) if (msg.ComputeResult(this, result))
{ {
bridge.CompleteSyncOrAsync(msg); bridge.CompleteSyncOrAsync(msg);
} }
} }
partial void OnCloseEcho(); partial void OnCloseEcho();
partial void OnCreateEcho(); partial void OnCreateEcho();
partial void OnDebugAbort(); partial void OnDebugAbort();
void ISocketCallback.OnHeartbeat() void ISocketCallback.OnHeartbeat()
{ {
try try
{ {
bridge.OnHeartbeat(true); // all the fun code is here bridge.OnHeartbeat(true); // all the fun code is here
} }
catch (Exception ex) catch (Exception ex)
{ {
OnInternalError(ex); OnInternalError(ex);
} }
} }
partial void OnWrapForLogging(ref Stream stream, string name); partial void OnWrapForLogging(ref Stream stream, string name);
private int ProcessBuffer(byte[] underlying, ref int offset, ref int count) private int ProcessBuffer(byte[] underlying, ref int offset, ref int count)
{ {
int messageCount = 0; int messageCount = 0;
RawResult result; RawResult result;
do do
{ {
int tmpOffset = offset, tmpCount = count; int tmpOffset = offset, tmpCount = count;
// we want TryParseResult to be able to mess with these without consequence // we want TryParseResult to be able to mess with these without consequence
result = TryParseResult(underlying, ref tmpOffset, ref tmpCount); result = TryParseResult(underlying, ref tmpOffset, ref tmpCount);
if (result.HasValue) if (result.HasValue)
{ {
messageCount++; messageCount++;
// entire message: update the external counters // entire message: update the external counters
offset = tmpOffset; offset = tmpOffset;
count = tmpCount; count = tmpCount;
multiplexer.Trace(result.ToString(), physicalName); multiplexer.Trace(result.ToString(), physicalName);
MatchResult(result); MatchResult(result);
} }
} while (result.HasValue); } while (result.HasValue);
return messageCount; return messageCount;
} }
private bool ProcessReadBytes(int bytesRead) private bool ProcessReadBytes(int bytesRead)
{ {
if (bytesRead <= 0) if (bytesRead <= 0)
{ {
multiplexer.Trace("EOF", physicalName); multiplexer.Trace("EOF", physicalName);
RecordConnectionFailed(ConnectionFailureType.SocketClosed); RecordConnectionFailed(ConnectionFailureType.SocketClosed);
return false; return false;
} }
Interlocked.Exchange(ref lastReadTickCount, Environment.TickCount); Interlocked.Exchange(ref lastReadTickCount, Environment.TickCount);
// reset unanswered write timestamp // reset unanswered write timestamp
Thread.VolatileWrite(ref firstUnansweredWriteTickCount, 0); Thread.VolatileWrite(ref firstUnansweredWriteTickCount, 0);
ioBufferBytes += bytesRead; ioBufferBytes += bytesRead;
multiplexer.Trace("More bytes available: " + bytesRead + " (" + ioBufferBytes + ")", physicalName); multiplexer.Trace("More bytes available: " + bytesRead + " (" + ioBufferBytes + ")", physicalName);
int offset = 0, count = ioBufferBytes; int offset = 0, count = ioBufferBytes;
int handled = ProcessBuffer(ioBuffer, ref offset, ref count); int handled = ProcessBuffer(ioBuffer, ref offset, ref count);
multiplexer.Trace("Processed: " + handled, physicalName); multiplexer.Trace("Processed: " + handled, physicalName);
if (handled != 0) if (handled != 0)
{ {
// read stuff // read stuff
if (count != 0) if (count != 0)
{ {
multiplexer.Trace("Copying remaining bytes: " + count, physicalName); multiplexer.Trace("Copying remaining bytes: " + count, physicalName);
// if anything was left over, we need to copy it to // if anything was left over, we need to copy it to
// the start of the buffer so it can be used next time // the start of the buffer so it can be used next time
Buffer.BlockCopy(ioBuffer, offset, ioBuffer, 0, count); Buffer.BlockCopy(ioBuffer, offset, ioBuffer, 0, count);
} }
ioBufferBytes = count; ioBufferBytes = count;
} }
return true; return true;
} }
void ISocketCallback.Read() void ISocketCallback.Read()
{ {
Interlocked.Increment(ref haveReader); Interlocked.Increment(ref haveReader);
try try
{ {
do do
{ {
int space = EnsureSpaceAndComputeBytesToRead(); int space = EnsureSpaceAndComputeBytesToRead();
var tmp = netStream; var tmp = netStream;
int bytesRead = tmp == null ? 0 : tmp.Read(ioBuffer, ioBufferBytes, space); int bytesRead = tmp == null ? 0 : tmp.Read(ioBuffer, ioBufferBytes, space);
if (!ProcessReadBytes(bytesRead)) return; // EOF if (!ProcessReadBytes(bytesRead)) return; // EOF
} while (socketToken.Available != 0); } while (socketToken.Available != 0);
multiplexer.Trace("Buffer exhausted", physicalName); multiplexer.Trace("Buffer exhausted", physicalName);
// ^^^ note that the socket manager will call us again when there is something to do // ^^^ note that the socket manager will call us again when there is something to do
} }
catch (Exception ex) catch (Exception ex)
{ {
RecordConnectionFailed(ConnectionFailureType.InternalFailure, ex); RecordConnectionFailed(ConnectionFailureType.InternalFailure, ex);
}finally }finally
{ {
Interlocked.Decrement(ref haveReader); Interlocked.Decrement(ref haveReader);
} }
} }
private RawResult ReadArray(byte[] buffer, ref int offset, ref int count) private RawResult ReadArray(byte[] buffer, ref int offset, ref int count)
{ {
var itemCount = ReadLineTerminatedString(ResultType.Integer, buffer, ref offset, ref count); var itemCount = ReadLineTerminatedString(ResultType.Integer, buffer, ref offset, ref count);
if (itemCount.HasValue) if (itemCount.HasValue)
{ {
long i64; long i64;
if (!itemCount.TryGetInt64(out i64)) throw ExceptionFactory.ConnectionFailure(multiplexer.IncludeDetailInExceptions, ConnectionFailureType.ProtocolFailure, "Invalid array length", bridge.ServerEndPoint); if (!itemCount.TryGetInt64(out i64)) throw ExceptionFactory.ConnectionFailure(multiplexer.IncludeDetailInExceptions, ConnectionFailureType.ProtocolFailure, "Invalid array length", bridge.ServerEndPoint);
int itemCountActual = checked((int)i64); int itemCountActual = checked((int)i64);
if (itemCountActual <= 0) return RawResult.EmptyArray; if (itemCountActual <= 0) return RawResult.EmptyArray;
var arr = new RawResult[itemCountActual]; var arr = new RawResult[itemCountActual];
for (int i = 0; i < itemCountActual; i++) for (int i = 0; i < itemCountActual; i++)
{ {
if (!(arr[i] = TryParseResult(buffer, ref offset, ref count)).HasValue) if (!(arr[i] = TryParseResult(buffer, ref offset, ref count)).HasValue)
return RawResult.Nil; return RawResult.Nil;
} }
return new RawResult(arr); return new RawResult(arr);
} }
return RawResult.Nil; return RawResult.Nil;
} }
private RawResult ReadBulkString(byte[] buffer, ref int offset, ref int count) private RawResult ReadBulkString(byte[] buffer, ref int offset, ref int count)
{ {
var prefix = ReadLineTerminatedString(ResultType.Integer, buffer, ref offset, ref count); var prefix = ReadLineTerminatedString(ResultType.Integer, buffer, ref offset, ref count);
if (prefix.HasValue) if (prefix.HasValue)
{ {
long i64; long i64;
if (!prefix.TryGetInt64(out i64)) throw ExceptionFactory.ConnectionFailure(multiplexer.IncludeDetailInExceptions, ConnectionFailureType.ProtocolFailure, "Invalid bulk string length", bridge.ServerEndPoint); if (!prefix.TryGetInt64(out i64)) throw ExceptionFactory.ConnectionFailure(multiplexer.IncludeDetailInExceptions, ConnectionFailureType.ProtocolFailure, "Invalid bulk string length", bridge.ServerEndPoint);
int bodySize = checked((int)i64); int bodySize = checked((int)i64);
if (bodySize < 0) if (bodySize < 0)
{ {
return new RawResult(ResultType.BulkString, null, 0, 0); return new RawResult(ResultType.BulkString, null, 0, 0);
} }
else if (count >= bodySize + 2) else if (count >= bodySize + 2)
{ {
if (buffer[offset + bodySize] != '\r' || buffer[offset + bodySize + 1] != '\n') if (buffer[offset + bodySize] != '\r' || buffer[offset + bodySize + 1] != '\n')
{ {
throw ExceptionFactory.ConnectionFailure(multiplexer.IncludeDetailInExceptions, ConnectionFailureType.ProtocolFailure, "Invalid bulk string terminator", bridge.ServerEndPoint); throw ExceptionFactory.ConnectionFailure(multiplexer.IncludeDetailInExceptions, ConnectionFailureType.ProtocolFailure, "Invalid bulk string terminator", bridge.ServerEndPoint);
} }
var result = new RawResult(ResultType.BulkString, buffer, offset, bodySize); var result = new RawResult(ResultType.BulkString, buffer, offset, bodySize);
offset += bodySize + 2; offset += bodySize + 2;
count -= bodySize + 2; count -= bodySize + 2;
return result; return result;
} }
} }
return RawResult.Nil; return RawResult.Nil;
} }
private RawResult ReadLineTerminatedString(ResultType type, byte[] buffer, ref int offset, ref int count) private RawResult ReadLineTerminatedString(ResultType type, byte[] buffer, ref int offset, ref int count)
{ {
int max = offset + count - 2; int max = offset + count - 2;
for (int i = offset; i < max; i++) for (int i = offset; i < max; i++)
{ {
if (buffer[i + 1] == '\r' && buffer[i + 2] == '\n') if (buffer[i + 1] == '\r' && buffer[i + 2] == '\n')
{ {
int len = i - offset + 1; int len = i - offset + 1;
var result = new RawResult(type, buffer, offset, len); var result = new RawResult(type, buffer, offset, len);
count -= (len + 2); count -= (len + 2);
offset += (len + 2); offset += (len + 2);
return result; return result;
} }
} }
return RawResult.Nil; return RawResult.Nil;
} }
void ISocketCallback.StartReading() void ISocketCallback.StartReading()
{ {
BeginReading(); BeginReading();
} }
RawResult TryParseResult(byte[] buffer, ref int offset, ref int count) RawResult TryParseResult(byte[] buffer, ref int offset, ref int count)
{ {
if(count == 0) return RawResult.Nil; if(count == 0) return RawResult.Nil;
char resultType = (char)buffer[offset++]; char resultType = (char)buffer[offset++];
count--; count--;
switch(resultType) switch(resultType)
{ {
case '+': // simple string case '+': // simple string
return ReadLineTerminatedString(ResultType.SimpleString, buffer, ref offset, ref count); return ReadLineTerminatedString(ResultType.SimpleString, buffer, ref offset, ref count);
case '-': // error case '-': // error
return ReadLineTerminatedString(ResultType.Error, buffer, ref offset, ref count); return ReadLineTerminatedString(ResultType.Error, buffer, ref offset, ref count);
case ':': // integer case ':': // integer
return ReadLineTerminatedString(ResultType.Integer, buffer, ref offset, ref count); return ReadLineTerminatedString(ResultType.Integer, buffer, ref offset, ref count);
case '$': // bulk string case '$': // bulk string
return ReadBulkString(buffer, ref offset, ref count); return ReadBulkString(buffer, ref offset, ref count);
case '*': // array case '*': // array
return ReadArray(buffer, ref offset, ref count); return ReadArray(buffer, ref offset, ref count);
default: default:
throw new InvalidOperationException("Unexpected response prefix: " + (char)resultType); throw new InvalidOperationException("Unexpected response prefix: " + (char)resultType);
} }
} }
partial void DebugEmulateStaleConnection(ref int firstUnansweredWrite); partial void DebugEmulateStaleConnection(ref int firstUnansweredWrite);
public void CheckForStaleConnection() public void CheckForStaleConnection()
{ {
int firstUnansweredWrite; int firstUnansweredWrite;
firstUnansweredWrite = Thread.VolatileRead(ref firstUnansweredWriteTickCount); firstUnansweredWrite = Thread.VolatileRead(ref firstUnansweredWriteTickCount);
DebugEmulateStaleConnection(ref firstUnansweredWrite); DebugEmulateStaleConnection(ref firstUnansweredWrite);
int now = Environment.TickCount; int now = Environment.TickCount;
if (firstUnansweredWrite != 0 && (now - firstUnansweredWrite) > this.multiplexer.RawConfig.SyncTimeout) if (firstUnansweredWrite != 0 && (now - firstUnansweredWrite) > this.multiplexer.RawConfig.SyncTimeout)
{ {
this.RecordConnectionFailed(ConnectionFailureType.SocketFailure, origin: "CheckForStaleConnection"); this.RecordConnectionFailed(ConnectionFailureType.SocketFailure, origin: "CheckForStaleConnection");
} }
} }
} }
} }
This source diff could not be displayed because it is too large. You can view the blob instead.
...@@ -373,6 +373,16 @@ public Task<byte[]> ScriptLoadAsync(string script, CommandFlags flags = CommandF ...@@ -373,6 +373,16 @@ public Task<byte[]> ScriptLoadAsync(string script, CommandFlags flags = CommandF
{ {
var msg = new RedisDatabase.ScriptLoadMessage(flags, script); var msg = new RedisDatabase.ScriptLoadMessage(flags, script);
return ExecuteAsync(msg, ResultProcessor.ScriptLoad); return ExecuteAsync(msg, ResultProcessor.ScriptLoad);
}
public LoadedLuaScript ScriptLoad(LuaScript script, CommandFlags flags = CommandFlags.None)
{
return script.Load(this, flags);
}
public Task<LoadedLuaScript> ScriptLoadAsync(LuaScript script, CommandFlags flags = CommandFlags.None)
{
return script.LoadAsync(this, flags);
} }
public void Shutdown(ShutdownMode shutdownMode = ShutdownMode.Default, CommandFlags flags = CommandFlags.None) public void Shutdown(ShutdownMode shutdownMode = ShutdownMode.Default, CommandFlags flags = CommandFlags.None)
...@@ -563,12 +573,35 @@ internal override RedisFeatures GetFeatures(int db, RedisKey key, CommandFlags f ...@@ -563,12 +573,35 @@ internal override RedisFeatures GetFeatures(int db, RedisKey key, CommandFlags f
public void SlaveOf(EndPoint endpoint, CommandFlags flags = CommandFlags.None) public void SlaveOf(EndPoint endpoint, CommandFlags flags = CommandFlags.None)
{ {
var msg = CreateSlaveOfMessage(endpoint, flags);
if (endpoint == server.EndPoint) if (endpoint == server.EndPoint)
{ {
throw new ArgumentException("Cannot slave to self"); throw new ArgumentException("Cannot slave to self");
} }
ExecuteSync(msg, ResultProcessor.DemandOK); // prepare the actual slaveof message (not sent yet)
var slaveofMsg = CreateSlaveOfMessage(endpoint, flags);
var configuration = this.multiplexer.RawConfig;
// attempt to cease having an opinion on the master; will resume that when replication completes
// (note that this may fail; we aren't depending on it)
if (!string.IsNullOrWhiteSpace(configuration.TieBreaker)
&& this.multiplexer.CommandMap.IsAvailable(RedisCommand.DEL))
{
var del = Message.Create(0, CommandFlags.FireAndForget | CommandFlags.NoRedirect, RedisCommand.DEL, (RedisKey)configuration.TieBreaker);
del.SetInternalCall();
server.QueueDirectFireAndForget(del, ResultProcessor.Boolean);
}
ExecuteSync(slaveofMsg, ResultProcessor.DemandOK);
// attempt to broadcast a reconfigure message to anybody listening to this server
var channel = this.multiplexer.ConfigurationChangedChannel;
if (channel != null && this.multiplexer.CommandMap.IsAvailable(RedisCommand.PUBLISH))
{
var pub = Message.Create(-1, CommandFlags.FireAndForget | CommandFlags.NoRedirect, RedisCommand.PUBLISH, (RedisValue)channel, RedisLiterals.Wildcard);
pub.SetInternalCall();
server.QueueDirectFireAndForget(pub, ResultProcessor.Int64);
}
} }
public Task SlaveOfAsync(EndPoint endpoint, CommandFlags flags = CommandFlags.None) public Task SlaveOfAsync(EndPoint endpoint, CommandFlags flags = CommandFlags.None)
......
using System;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;
using System.Reflection.Emit;
using System.Security.Cryptography;
using System.Text;
using System.Text.RegularExpressions;
using System.Threading.Tasks;
namespace StackExchange.Redis
{
class ScriptParameterMapper
{
public struct ScriptParameters
{
public RedisKey[] Keys;
public RedisValue[] Arguments;
public static readonly ConstructorInfo Cons = typeof(ScriptParameters).GetConstructor(new[] { typeof(RedisKey[]), typeof(RedisValue[]) });
public ScriptParameters(RedisKey[] keys, RedisValue[] args)
{
Keys = keys;
Arguments = args;
}
}
static readonly Regex ParameterExtractor = new Regex(@"@(?<paramName> ([a-z]|_) ([a-z]|_|\d)*)", RegexOptions.Compiled | RegexOptions.IgnoreCase | RegexOptions.IgnorePatternWhitespace);
static string[] ExtractParameters(string script)
{
var ps = ParameterExtractor.Matches(script);
if (ps.Count == 0) return null;
var ret = new HashSet<string>();
for (var i = 0; i < ps.Count; i++)
{
var c = ps[i];
var ix = c.Index - 1;
if (ix >= 0)
{
var prevChar = script[ix];
// don't consider this a parameter if it's in the middle of word (ie. if it's preceeded by a letter)
if (char.IsLetterOrDigit(prevChar) || prevChar == '_') continue;
// this is an escape, ignore it
if (prevChar == '@') continue;
}
var n = c.Groups["paramName"].Value;
if (!ret.Contains(n)) ret.Add(n);
}
return ret.ToArray();
}
static string MakeOrdinalScriptWithoutKeys(string rawScript, string[] args)
{
var ps = ParameterExtractor.Matches(rawScript);
if (ps.Count == 0) return rawScript;
var ret = new StringBuilder();
var upTo = 0;
for (var i = 0; i < ps.Count; i++)
{
var capture = ps[i];
var name = capture.Groups["paramName"].Value;
var ix = capture.Index;
ret.Append(rawScript.Substring(upTo, ix - upTo));
var argIx = Array.IndexOf(args, name);
if (argIx != -1)
{
ret.Append("ARGV[");
ret.Append(argIx + 1);
ret.Append("]");
}
else
{
var isEscape = false;
var prevIx = capture.Index - 1;
if (prevIx >= 0)
{
var prevChar = rawScript[prevIx];
isEscape = prevChar == '@';
}
if (isEscape)
{
// strip the @ off, so just the one triggering the escape exists
ret.Append(capture.Groups["paramName"].Value);
}
else
{
ret.Append(capture.Value);
}
}
upTo = capture.Index + capture.Length;
}
ret.Append(rawScript.Substring(upTo, rawScript.Length - upTo));
return ret.ToString();
}
static void LoadMember(ILGenerator il, MemberInfo member)
{
// stack starts:
// T(*?)
var asField = member as FieldInfo;
if (asField != null)
{
il.Emit(OpCodes.Ldfld, asField); // typeof(member)
return;
}
var asProp = member as PropertyInfo;
if (asProp != null)
{
var getter = asProp.GetGetMethod();
if (getter.IsVirtual)
{
il.Emit(OpCodes.Callvirt, getter); // typeof(member)
}
else
{
il.Emit(OpCodes.Call, getter); // typeof(member)
}
return;
}
throw new Exception("Should't be possible");
}
static readonly MethodInfo RedisValue_FromInt = typeof(RedisValue).GetMethod("op_Implicit", new[] { typeof(int) });
static readonly MethodInfo RedisValue_FromNullableInt = typeof(RedisValue).GetMethod("op_Implicit", new[] { typeof(int?) });
static readonly MethodInfo RedisValue_FromLong = typeof(RedisValue).GetMethod("op_Implicit", new[] { typeof(long) });
static readonly MethodInfo RedisValue_FromNullableLong = typeof(RedisValue).GetMethod("op_Implicit", new[] { typeof(long?) });
static readonly MethodInfo RedisValue_FromDouble= typeof(RedisValue).GetMethod("op_Implicit", new[] { typeof(double) });
static readonly MethodInfo RedisValue_FromNullableDouble = typeof(RedisValue).GetMethod("op_Implicit", new[] { typeof(double?) });
static readonly MethodInfo RedisValue_FromString = typeof(RedisValue).GetMethod("op_Implicit", new[] { typeof(string) });
static readonly MethodInfo RedisValue_FromByteArray = typeof(RedisValue).GetMethod("op_Implicit", new[] { typeof(byte[]) });
static readonly MethodInfo RedisValue_FromBool = typeof(RedisValue).GetMethod("op_Implicit", new[] { typeof(bool) });
static readonly MethodInfo RedisValue_FromNullableBool = typeof(RedisValue).GetMethod("op_Implicit", new[] { typeof(bool?) });
static readonly MethodInfo RedisKey_AsRedisValue = typeof(RedisKey).GetMethod("AsRedisValue", BindingFlags.NonPublic | BindingFlags.Instance);
static void ConvertToRedisValue(MemberInfo member, ILGenerator il, LocalBuilder needsPrefixBool, ref LocalBuilder redisKeyLoc)
{
// stack starts:
// typeof(member)
var t = member is FieldInfo ? ((FieldInfo)member).FieldType : ((PropertyInfo)member).PropertyType;
if (t == typeof(RedisValue))
{
// They've already converted for us, don't do anything
return;
}
if (t == typeof(RedisKey))
{
redisKeyLoc = redisKeyLoc ?? il.DeclareLocal(typeof(RedisKey));
PrefixIfNeeded(il, needsPrefixBool, ref redisKeyLoc); // RedisKey
il.Emit(OpCodes.Stloc, redisKeyLoc); // --empty--
il.Emit(OpCodes.Ldloca, redisKeyLoc); // RedisKey*
il.Emit(OpCodes.Call, RedisKey_AsRedisValue); // RedisValue
return;
}
MethodInfo convertOp = null;
if (t == typeof(int)) convertOp = RedisValue_FromInt;
if (t == typeof(int?)) convertOp = RedisValue_FromNullableInt;
if (t == typeof(long)) convertOp = RedisValue_FromLong;
if (t == typeof(long?)) convertOp = RedisValue_FromNullableLong;
if (t == typeof(double)) convertOp = RedisValue_FromDouble;
if (t == typeof(double?)) convertOp = RedisValue_FromNullableDouble;
if (t == typeof(string)) convertOp = RedisValue_FromString;
if (t == typeof(byte[])) convertOp = RedisValue_FromByteArray;
if (t == typeof(bool)) convertOp = RedisValue_FromBool;
if (t == typeof(bool?)) convertOp = RedisValue_FromNullableBool;
il.Emit(OpCodes.Call, convertOp);
// stack ends:
// RedisValue
}
/// <summary>
/// Turns a script with @namedParameters into a LuaScript that can be executed
/// against a given IDatabase(Async) object
/// </summary>
public static LuaScript PrepareScript(string script)
{
var ps = ExtractParameters(script);
var ordinalScript = MakeOrdinalScriptWithoutKeys(script, ps);
return new LuaScript(script, ordinalScript, ps);
}
static readonly HashSet<Type> ConvertableTypes =
new HashSet<Type> {
typeof(int),
typeof(int?),
typeof(long),
typeof(long?),
typeof(double),
typeof(double?),
typeof(string),
typeof(byte[]),
typeof(bool),
typeof(bool?),
typeof(RedisKey),
typeof(RedisValue)
};
/// <summary>
/// Determines whether or not the given type can be used to provide parameters for the given LuaScript.
/// </summary>
public static bool IsValidParameterHash(Type t, LuaScript script, out string missingMember, out string badTypeMember)
{
for (var i = 0; i < script.Arguments.Length; i++)
{
var argName = script.Arguments[i];
var member = t.GetMember(argName).Where(m => m is PropertyInfo || m is FieldInfo).SingleOrDefault();
if (member == null)
{
missingMember = argName;
badTypeMember = null;
return false;
}
var memberType = member is FieldInfo ? ((FieldInfo)member).FieldType : ((PropertyInfo)member).PropertyType;
if(!ConvertableTypes.Contains(memberType)){
missingMember = null;
badTypeMember = argName;
return false;
}
}
missingMember = badTypeMember = null;
return true;
}
static void PrefixIfNeeded(ILGenerator il, LocalBuilder needsPrefixBool, ref LocalBuilder redisKeyLoc)
{
// top of stack is
// RedisKey
var getVal = typeof(RedisKey?).GetProperty("Value").GetGetMethod();
var prepend = typeof(RedisKey).GetMethod("Prepend");
var doNothing = il.DefineLabel();
redisKeyLoc = redisKeyLoc ?? il.DeclareLocal(typeof(RedisKey));
il.Emit(OpCodes.Ldloc, needsPrefixBool); // RedisKey bool
il.Emit(OpCodes.Brfalse, doNothing); // RedisKey
il.Emit(OpCodes.Stloc, redisKeyLoc); // --empty--
il.Emit(OpCodes.Ldloca, redisKeyLoc); // RedisKey*
il.Emit(OpCodes.Ldarga_S, 1); // RedisKey* RedisKey?*
il.Emit(OpCodes.Call, getVal); // RedisKey* RedisKey
il.Emit(OpCodes.Call, prepend); // RedisKey
il.MarkLabel(doNothing); // RedisKey
}
/// <summary>
/// Creates a Func that extracts parameters from the given type for use by a LuaScript.
///
/// Members that are RedisKey's get extracted to be passed in as keys to redis; all members that
/// appear in the script get extracted as RedisValue arguments to be sent up as args.
///
/// We send all values as arguments so we don't have to prepare the same script for different parameter
/// types.
///
/// The created Func takes a RedisKey, which will be prefixed to all keys (and arguments of type RedisKey) for
/// keyspace isolation.
/// </summary>
public static Func<object, RedisKey?, ScriptParameters> GetParameterExtractor(Type t, LuaScript script)
{
string ignored;
if (!IsValidParameterHash(t, script, out ignored, out ignored)) throw new Exception("Shouldn't be possible");
var keys = new List<MemberInfo>();
var args = new List<MemberInfo>();
for (var i = 0; i < script.Arguments.Length; i++)
{
var argName = script.Arguments[i];
var member = t.GetMember(argName).Where(m => m is PropertyInfo || m is FieldInfo).SingleOrDefault();
var memberType = member is FieldInfo ? ((FieldInfo)member).FieldType : ((PropertyInfo)member).PropertyType;
if (memberType == typeof(RedisKey))
{
keys.Add(member);
}
args.Add(member);
}
var nullableRedisKeyHasValue = typeof(RedisKey?).GetProperty("HasValue").GetGetMethod();
var dyn = new DynamicMethod("ParameterExtractor_" + t.FullName + "_" + script.OriginalScript.GetHashCode(), typeof(ScriptParameters), new[] { typeof(object), typeof(RedisKey?) }, restrictedSkipVisibility: true);
var il = dyn.GetILGenerator();
// only init'd if we use it
LocalBuilder redisKeyLoc = null;
var loc = il.DeclareLocal(t);
il.Emit(OpCodes.Ldarg_0); // object
if (t.IsValueType)
{
il.Emit(OpCodes.Unbox_Any, t); // T
}
else
{
il.Emit(OpCodes.Castclass, t); // T
}
il.Emit(OpCodes.Stloc, loc); // --empty--
var needsKeyPrefixLoc = il.DeclareLocal(typeof(bool));
il.Emit(OpCodes.Ldarga_S, 1); // RedisKey?*
il.Emit(OpCodes.Call, nullableRedisKeyHasValue); // bool
il.Emit(OpCodes.Stloc, needsKeyPrefixLoc); // --empty--
if (keys.Count == 0)
{
// if there are no keys, don't allocate
il.Emit(OpCodes.Ldnull); // null
}
else
{
il.Emit(OpCodes.Ldc_I4, keys.Count); // int
il.Emit(OpCodes.Newarr, typeof(RedisKey)); // RedisKey[]
}
for (var i = 0; i < keys.Count; i++)
{
il.Emit(OpCodes.Dup); // RedisKey[] RedisKey[]
il.Emit(OpCodes.Ldc_I4, i); // RedisKey[] RedisKey[] int
if (t.IsValueType)
{
il.Emit(OpCodes.Ldloca, loc); // RedisKey[] RedisKey[] int T*
}
else
{
il.Emit(OpCodes.Ldloc, loc); // RedisKey[] RedisKey[] int T
}
LoadMember(il, keys[i]); // RedisKey[] RedisKey[] int RedisKey
PrefixIfNeeded(il, needsKeyPrefixLoc, ref redisKeyLoc); // RedisKey[] RedisKey[] int RedisKey
il.Emit(OpCodes.Stelem, typeof(RedisKey)); // RedisKey[]
}
if (args.Count == 0)
{
// if there are no args, don't allocate
il.Emit(OpCodes.Ldnull); // RedisKey[] null
}
else
{
il.Emit(OpCodes.Ldc_I4, args.Count); // RedisKey[] int
il.Emit(OpCodes.Newarr, typeof(RedisValue)); // RedisKey[] RedisValue[]
}
for (var i = 0; i < args.Count; i++)
{
il.Emit(OpCodes.Dup); // RedisKey[] RedisValue[] RedisValue[]
il.Emit(OpCodes.Ldc_I4, i); // RedisKey[] RedisValue[] RedisValue[] int
if (t.IsValueType)
{
il.Emit(OpCodes.Ldloca, loc); // RedisKey[] RedisValue[] RedisValue[] int T*
}
else
{
il.Emit(OpCodes.Ldloc, loc); // RedisKey[] RedisValue[] RedisValue[] int T
}
var member = args[i];
LoadMember(il, member); // RedisKey[] RedisValue[] RedisValue[] int memberType
ConvertToRedisValue(member, il, needsKeyPrefixLoc, ref redisKeyLoc); // RedisKey[] RedisValue[] RedisValue[] int RedisValue
il.Emit(OpCodes.Stelem, typeof(RedisValue)); // RedisKey[] RedisValue[]
}
il.Emit(OpCodes.Newobj, ScriptParameters.Cons); // ScriptParameters
il.Emit(OpCodes.Ret); // --empty--
var ret = (Func<object, RedisKey?, ScriptParameters>)dyn.CreateDelegate(typeof(Func<object, RedisKey?, ScriptParameters>));
return ret;
}
}
}
...@@ -113,6 +113,7 @@ ...@@ -113,6 +113,7 @@
<Compile Include="..\StackExchange.Redis\StackExchange\Redis\MessageCompletable.cs" /> <Compile Include="..\StackExchange.Redis\StackExchange\Redis\MessageCompletable.cs" />
<Compile Include="..\StackExchange.Redis\StackExchange\Redis\MessageQueue.cs" /> <Compile Include="..\StackExchange.Redis\StackExchange\Redis\MessageQueue.cs" />
<Compile Include="..\StackExchange.Redis\StackExchange\Redis\MigrateOptions.cs" /> <Compile Include="..\StackExchange.Redis\StackExchange\Redis\MigrateOptions.cs" />
<Compile Include="..\StackExchange.Redis\StackExchange\Redis\LuaScript.cs" />
<Compile Include="..\StackExchange.Redis\StackExchange\Redis\Order.cs" /> <Compile Include="..\StackExchange.Redis\StackExchange\Redis\Order.cs" />
<Compile Include="..\StackExchange.Redis\StackExchange\Redis\PhysicalBridge.cs" /> <Compile Include="..\StackExchange.Redis\StackExchange\Redis\PhysicalBridge.cs" />
<Compile Include="..\StackExchange.Redis\StackExchange\Redis\PhysicalConnection.cs" /> <Compile Include="..\StackExchange.Redis\StackExchange\Redis\PhysicalConnection.cs" />
...@@ -136,6 +137,7 @@ ...@@ -136,6 +137,7 @@
<Compile Include="..\StackExchange.Redis\StackExchange\Redis\ResultBox.cs" /> <Compile Include="..\StackExchange.Redis\StackExchange\Redis\ResultBox.cs" />
<Compile Include="..\StackExchange.Redis\StackExchange\Redis\ResultProcessor.cs" /> <Compile Include="..\StackExchange.Redis\StackExchange\Redis\ResultProcessor.cs" />
<Compile Include="..\StackExchange.Redis\StackExchange\Redis\ResultType.cs" /> <Compile Include="..\StackExchange.Redis\StackExchange\Redis\ResultType.cs" />
<Compile Include="..\StackExchange.Redis\StackExchange\Redis\ScriptParameterMapper.cs" />
<Compile Include="..\StackExchange.Redis\StackExchange\Redis\SaveType.cs" /> <Compile Include="..\StackExchange.Redis\StackExchange\Redis\SaveType.cs" />
<Compile Include="..\StackExchange.Redis\StackExchange\Redis\ServerCounters.cs" /> <Compile Include="..\StackExchange.Redis\StackExchange\Redis\ServerCounters.cs" />
<Compile Include="..\StackExchange.Redis\StackExchange\Redis\ServerEndPoint.cs" /> <Compile Include="..\StackExchange.Redis\StackExchange\Redis\ServerEndPoint.cs" />
......
...@@ -107,6 +107,7 @@ ...@@ -107,6 +107,7 @@
<Compile Include="..\StackExchange.Redis\StackExchange\Redis\MessageCompletable.cs" /> <Compile Include="..\StackExchange.Redis\StackExchange\Redis\MessageCompletable.cs" />
<Compile Include="..\StackExchange.Redis\StackExchange\Redis\MessageQueue.cs" /> <Compile Include="..\StackExchange.Redis\StackExchange\Redis\MessageQueue.cs" />
<Compile Include="..\StackExchange.Redis\StackExchange\Redis\MigrateOptions.cs" /> <Compile Include="..\StackExchange.Redis\StackExchange\Redis\MigrateOptions.cs" />
<Compile Include="..\StackExchange.Redis\StackExchange\Redis\LuaScript.cs" />
<Compile Include="..\StackExchange.Redis\StackExchange\Redis\Order.cs" /> <Compile Include="..\StackExchange.Redis\StackExchange\Redis\Order.cs" />
<Compile Include="..\StackExchange.Redis\StackExchange\Redis\PhysicalBridge.cs" /> <Compile Include="..\StackExchange.Redis\StackExchange\Redis\PhysicalBridge.cs" />
<Compile Include="..\StackExchange.Redis\StackExchange\Redis\PhysicalConnection.cs" /> <Compile Include="..\StackExchange.Redis\StackExchange\Redis\PhysicalConnection.cs" />
...@@ -130,6 +131,7 @@ ...@@ -130,6 +131,7 @@
<Compile Include="..\StackExchange.Redis\StackExchange\Redis\ResultBox.cs" /> <Compile Include="..\StackExchange.Redis\StackExchange\Redis\ResultBox.cs" />
<Compile Include="..\StackExchange.Redis\StackExchange\Redis\ResultProcessor.cs" /> <Compile Include="..\StackExchange.Redis\StackExchange\Redis\ResultProcessor.cs" />
<Compile Include="..\StackExchange.Redis\StackExchange\Redis\ResultType.cs" /> <Compile Include="..\StackExchange.Redis\StackExchange\Redis\ResultType.cs" />
<Compile Include="..\StackExchange.Redis\StackExchange\Redis\ScriptParameterMapper.cs" />
<Compile Include="..\StackExchange.Redis\StackExchange\Redis\SaveType.cs" /> <Compile Include="..\StackExchange.Redis\StackExchange\Redis\SaveType.cs" />
<Compile Include="..\StackExchange.Redis\StackExchange\Redis\ServerCounters.cs" /> <Compile Include="..\StackExchange.Redis\StackExchange\Redis\ServerCounters.cs" />
<Compile Include="..\StackExchange.Redis\StackExchange\Redis\ServerEndPoint.cs" /> <Compile Include="..\StackExchange.Redis\StackExchange\Redis\ServerEndPoint.cs" />
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment