Commit d0880697 authored by Marc Gravell's avatar Marc Gravell

Simpler SSL connect

parent bc9e4fe2
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
using System.IO; using System.IO;
using System.Net; using System.Net;
using NUnit.Framework; using NUnit.Framework;
using System.Linq;
namespace StackExchange.Redis.Tests namespace StackExchange.Redis.Tests
{ {
...@@ -11,10 +12,17 @@ namespace StackExchange.Redis.Tests ...@@ -11,10 +12,17 @@ namespace StackExchange.Redis.Tests
public class SSL : TestBase public class SSL : TestBase
{ {
[Test] [Test]
[TestCase(6379, null)] [TestCase(6379, false, false)]
[TestCase(6380, "as if we care")] [TestCase(6380, true, false)]
public void ConnectToSSLServer(int port, string sslHost) [TestCase(6380, true, true)]
public void ConnectToSSLServer(int port, bool useSsl, bool specifyHost)
{ {
string host = null;
const string path = @"D:\RedisSslHost.txt"; // because I choose not to advertise my server here!
if (File.Exists(path)) host = File.ReadLines(path).First();
if (string.IsNullOrWhiteSpace(host)) Assert.Inconclusive("no ssl host specified at: " + path);
var config = new ConfigurationOptions var config = new ConfigurationOptions
{ {
CommandMap = CommandMap.Create( // looks like "config" is disabled CommandMap = CommandMap.Create( // looks like "config" is disabled
...@@ -24,18 +32,36 @@ public void ConnectToSSLServer(int port, string sslHost) ...@@ -24,18 +32,36 @@ public void ConnectToSSLServer(int port, string sslHost)
{ "cluster", null } { "cluster", null }
} }
), ),
SslHost = sslHost, EndPoints = { { host, port} },
EndPoints = { { "sslredis", port} },
AllowAdmin = true, AllowAdmin = true,
SyncTimeout = Debugger.IsAttached ? int.MaxValue : 5000 SyncTimeout = Debugger.IsAttached ? int.MaxValue : 5000
}; };
config.CertificateValidation += (sender, cert, chain, errors) => if(useSsl)
{ {
Console.WriteLine("cert issued to: " + cert.Subject); config.UseSsl = useSsl;
return true; // fingers in ears, pretend we don't know this is wrong if (specifyHost)
}; {
using (var muxer = ConnectionMultiplexer.Connect(config, Console.Out)) config.SslHost = host;
}
config.CertificateValidation += (sender, cert, chain, errors) =>
{
Console.WriteLine("errors: " + errors);
Console.WriteLine("cert issued to: " + cert.Subject);
return true; // fingers in ears, pretend we don't know this is wrong
};
}
var configString = config.ToString();
Console.WriteLine("config: " + configString);
var clone = ConfigurationOptions.Parse(configString);
Assert.AreEqual(configString, clone.ToString(), "config string");
using(var log = new StringWriter())
using (var muxer = ConnectionMultiplexer.Connect(config, log))
{ {
Console.WriteLine("Connect log:");
Console.WriteLine(log);
Console.WriteLine("====");
muxer.ConnectionFailed += OnConnectionFailed; muxer.ConnectionFailed += OnConnectionFailed;
muxer.InternalError += OnInternalError; muxer.InternalError += OnInternalError;
var db = muxer.GetDatabase(); var db = muxer.GetDatabase();
...@@ -66,13 +92,14 @@ public void ConnectToSSLServer(int port, string sslHost) ...@@ -66,13 +92,14 @@ public void ConnectToSSLServer(int port, string sslHost)
// perf: sync/multi-threaded // perf: sync/multi-threaded
TestConcurrent(db, key, 30, 10); TestConcurrent(db, key, 30, 10);
TestConcurrent(db, key, 30, 20); //TestConcurrent(db, key, 30, 20);
TestConcurrent(db, key, 30, 30); //TestConcurrent(db, key, 30, 30);
TestConcurrent(db, key, 30, 40); //TestConcurrent(db, key, 30, 40);
TestConcurrent(db, key, 30, 50); //TestConcurrent(db, key, 30, 50);
} }
} }
private static void TestConcurrent(IDatabase db, RedisKey key, int SyncLoop, int Threads) private static void TestConcurrent(IDatabase db, RedisKey key, int SyncLoop, int Threads)
{ {
long value; long value;
......
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.IO; using System.IO;
using System.Net; using System.Net;
using System.Net.Security; using System.Net.Security;
using System.Text; using System.Text;
using System.Threading.Tasks; using System.Threading.Tasks;
namespace StackExchange.Redis namespace StackExchange.Redis
{ {
/// <summary> /// <summary>
/// Specifies the proxy that is being used to communicate to redis /// Specifies the proxy that is being used to communicate to redis
/// </summary> /// </summary>
public enum Proxy public enum Proxy
{ {
/// <summary> /// <summary>
...@@ -22,25 +22,25 @@ public enum Proxy ...@@ -22,25 +22,25 @@ public enum Proxy
/// Communication via <a href="https://github.com/twitter/twemproxy">twemproxy</a> /// Communication via <a href="https://github.com/twitter/twemproxy">twemproxy</a>
/// </summary> /// </summary>
Twemproxy Twemproxy
} }
/// <summary> /// <summary>
/// The options relevant to a set of redis connections /// The options relevant to a set of redis connections
/// </summary> /// </summary>
public sealed class ConfigurationOptions : ICloneable public sealed class ConfigurationOptions : ICloneable
{ {
internal const string DefaultTieBreaker = "__Booksleeve_TieBreak", DefaultConfigurationChannel = "__Booksleeve_MasterChanged"; internal const string DefaultTieBreaker = "__Booksleeve_TieBreak", DefaultConfigurationChannel = "__Booksleeve_MasterChanged";
private const string AllowAdminPrefix = "allowAdmin=", SyncTimeoutPrefix = "syncTimeout=", private const string AllowAdminPrefix = "allowAdmin=", SyncTimeoutPrefix = "syncTimeout=",
ServiceNamePrefix = "serviceName=", ClientNamePrefix = "name=", KeepAlivePrefix = "keepAlive=", ServiceNamePrefix = "serviceName=", ClientNamePrefix = "name=", KeepAlivePrefix = "keepAlive=",
VersionPrefix = "version=", ConnectTimeoutPrefix = "connectTimeout=", PasswordPrefix = "password=", VersionPrefix = "version=", ConnectTimeoutPrefix = "connectTimeout=", PasswordPrefix = "password=",
TieBreakerPrefix = "tiebreaker=", WriteBufferPrefix = "writeBuffer=", SslHostPrefix = "sslHost=", TieBreakerPrefix = "tiebreaker=", WriteBufferPrefix = "writeBuffer=", UseSslPrefix = "ssl=", SslHostPrefix = "sslHost=",
ConfigChannelPrefix = "configChannel=", AbortOnConnectFailPrefix = "abortConnect=", ResolveDnsPrefix = "resolveDns=", ConfigChannelPrefix = "configChannel=", AbortOnConnectFailPrefix = "abortConnect=", ResolveDnsPrefix = "resolveDns=",
ChannelPrefixPrefix = "channelPrefix=", ProxyPrefix = "proxy="; ChannelPrefixPrefix = "channelPrefix=", ProxyPrefix = "proxy=";
private readonly EndPointCollection endpoints = new EndPointCollection(); private readonly EndPointCollection endpoints = new EndPointCollection();
private bool? allowAdmin, abortOnConnectFail, resolveDns; private bool? allowAdmin, abortOnConnectFail, resolveDns, useSsl;
private string clientName, serviceName, password, tieBreaker, sslHost, configChannel; private string clientName, serviceName, password, tieBreaker, sslHost, configChannel;
...@@ -52,42 +52,47 @@ public sealed class ConfigurationOptions : ICloneable ...@@ -52,42 +52,47 @@ public sealed class ConfigurationOptions : ICloneable
private Proxy? proxy; private Proxy? proxy;
/// <summary> /// <summary>
/// A LocalCertificateSelectionCallback delegate responsible for selecting the certificate used for authentication; note /// A LocalCertificateSelectionCallback delegate responsible for selecting the certificate used for authentication; note
/// that this cannot be specified in the configuration-string. /// that this cannot be specified in the configuration-string.
/// </summary> /// </summary>
[System.Diagnostics.CodeAnalysis.SuppressMessage("Microsoft.Design", "CA1009:DeclareEventHandlersCorrectly")] [System.Diagnostics.CodeAnalysis.SuppressMessage("Microsoft.Design", "CA1009:DeclareEventHandlersCorrectly")]
public event LocalCertificateSelectionCallback CertificateSelection; public event LocalCertificateSelectionCallback CertificateSelection;
/// <summary> /// <summary>
/// A RemoteCertificateValidationCallback delegate responsible for validating the certificate supplied by the remote party; note /// A RemoteCertificateValidationCallback delegate responsible for validating the certificate supplied by the remote party; note
/// that this cannot be specified in the configuration-string. /// that this cannot be specified in the configuration-string.
/// </summary> /// </summary>
[System.Diagnostics.CodeAnalysis.SuppressMessage("Microsoft.Design", "CA1009:DeclareEventHandlersCorrectly")] [System.Diagnostics.CodeAnalysis.SuppressMessage("Microsoft.Design", "CA1009:DeclareEventHandlersCorrectly")]
public event RemoteCertificateValidationCallback CertificateValidation; public event RemoteCertificateValidationCallback CertificateValidation;
/// <summary> /// <summary>
/// Gets or sets whether connect/configuration timeouts should be explicitly notified via a TimeoutException /// Gets or sets whether connect/configuration timeouts should be explicitly notified via a TimeoutException
/// </summary> /// </summary>
public bool AbortOnConnectFail { get { return abortOnConnectFail ?? true; } set { abortOnConnectFail = value; } } public bool AbortOnConnectFail { get { return abortOnConnectFail ?? true; } set { abortOnConnectFail = value; } }
/// <summary> /// <summary>
/// Indicates whether admin operations should be allowed /// Indicates whether admin operations should be allowed
/// </summary> /// </summary>
public bool AllowAdmin { get { return allowAdmin.GetValueOrDefault(); } set { allowAdmin = value; } } public bool AllowAdmin { get { return allowAdmin.GetValueOrDefault(); } set { allowAdmin = value; } }
/// <summary> /// <summary>
/// Automatically encodes and decodes channels /// Indicates whether the connection should be encrypted
/// </summary> /// </summary>
public RedisChannel ChannelPrefix { get;set; } public bool UseSsl { get { return useSsl.GetValueOrDefault(); } set { useSsl = value; } }
/// <summary>
/// The client name to user for all connections /// <summary>
/// </summary> /// Automatically encodes and decodes channels
/// </summary>
public RedisChannel ChannelPrefix { get;set; }
/// <summary>
/// The client name to user for all connections
/// </summary>
public string ClientName { get { return clientName; } set { clientName = value; } } public string ClientName { get { return clientName; } set { clientName = value; } }
/// <summary> /// <summary>
/// The command-map associated with this configuration /// The command-map associated with this configuration
/// </summary> /// </summary>
public CommandMap CommandMap public CommandMap CommandMap
{ {
get get
...@@ -100,7 +105,7 @@ public CommandMap CommandMap ...@@ -100,7 +105,7 @@ public CommandMap CommandMap
default: default:
return CommandMap.Default; return CommandMap.Default;
} }
} }
set set
{ {
if (value == null) throw new ArgumentNullException("value"); if (value == null) throw new ArgumentNullException("value");
...@@ -108,360 +113,367 @@ public CommandMap CommandMap ...@@ -108,360 +113,367 @@ public CommandMap CommandMap
} }
} }
/// <summary> /// <summary>
/// Channel to use for broadcasting and listening for configuration change notification /// Channel to use for broadcasting and listening for configuration change notification
/// </summary> /// </summary>
public string ConfigurationChannel { get { return configChannel ?? DefaultConfigurationChannel; } set { configChannel = value; } } public string ConfigurationChannel { get { return configChannel ?? DefaultConfigurationChannel; } set { configChannel = value; } }
/// <summary> /// <summary>
/// Specifies the time in milliseconds that should be allowed for connection /// Specifies the time in milliseconds that should be allowed for connection
/// </summary> /// </summary>
public int ConnectTimeout { get { return connectTimeout ?? SyncTimeout; } set { connectTimeout = value; } } public int ConnectTimeout { get { return connectTimeout ?? SyncTimeout; } set { connectTimeout = value; } }
/// <summary> /// <summary>
/// The server version to assume /// The server version to assume
/// </summary> /// </summary>
public Version DefaultVersion { get { return defaultVersion ?? RedisFeatures.v2_0_0; } set { defaultVersion = value; } } public Version DefaultVersion { get { return defaultVersion ?? RedisFeatures.v2_0_0; } set { defaultVersion = value; } }
/// <summary> /// <summary>
/// The endpoints defined for this configuration /// The endpoints defined for this configuration
/// </summary> /// </summary>
public EndPointCollection EndPoints { get { return endpoints; } } public EndPointCollection EndPoints { get { return endpoints; } }
/// <summary> /// <summary>
/// Specifies the time in seconds at which connections should be pinged to ensure validity /// Specifies the time in seconds at which connections should be pinged to ensure validity
/// </summary> /// </summary>
public int KeepAlive { get { return keepAlive.GetValueOrDefault(-1); } set { keepAlive = value; } } public int KeepAlive { get { return keepAlive.GetValueOrDefault(-1); } set { keepAlive = value; } }
/// <summary> /// <summary>
/// The password to use to authenticate with the server /// The password to use to authenticate with the server
/// </summary> /// </summary>
public string Password { get { return password; } set { password = value; } } public string Password { get { return password; } set { password = value; } }
/// <summary> /// <summary>
/// Indicates whether admin operations should be allowed /// Indicates whether admin operations should be allowed
/// </summary> /// </summary>
public Proxy Proxy { get { return proxy.GetValueOrDefault(); } set { proxy = value; } } public Proxy Proxy { get { return proxy.GetValueOrDefault(); } set { proxy = value; } }
/// <summary> /// <summary>
/// Indicates whether endpoints should be resolved via DNS before connecting /// Indicates whether endpoints should be resolved via DNS before connecting
/// </summary> /// </summary>
public bool ResolveDns { get { return resolveDns.GetValueOrDefault(); } set { resolveDns = value; } } public bool ResolveDns { get { return resolveDns.GetValueOrDefault(); } set { resolveDns = value; } }
/// <summary> /// <summary>
/// The service name used to resolve a service via sentinel /// The service name used to resolve a service via sentinel
/// </summary> /// </summary>
public string ServiceName { get { return serviceName; } set { serviceName = value; } } public string ServiceName { get { return serviceName; } set { serviceName = value; } }
/// <summary> /// <summary>
/// Gets or sets the SocketManager instance to be used with these options; if this is null a per-multiplexer /// Gets or sets the SocketManager instance to be used with these options; if this is null a per-multiplexer
/// SocketManager is created automatically. /// SocketManager is created automatically.
/// </summary> /// </summary>
public SocketManager SocketManager { get;set; } public SocketManager SocketManager { get;set; }
/// <summary> /// <summary>
/// The target-host to use when validating SSL certificate; setting a value here enables SSL mode /// The target-host to use when validating SSL certificate; setting a value here enables SSL mode
/// </summary> /// </summary>
public string SslHost { get { return sslHost; } set { sslHost = value; } } public string SslHost { get { return sslHost; } set { sslHost = value; } }
/// <summary> /// <summary>
/// Specifies the time in milliseconds that the system should allow for synchronous operations /// Specifies the time in milliseconds that the system should allow for synchronous operations
/// </summary> /// </summary>
public int SyncTimeout { get { return syncTimeout.GetValueOrDefault(1000); } set { syncTimeout = value; } } public int SyncTimeout { get { return syncTimeout.GetValueOrDefault(1000); } set { syncTimeout = value; } }
/// <summary> /// <summary>
/// Tie-breaker used to choose between masters (must match the endpoint exactly) /// Tie-breaker used to choose between masters (must match the endpoint exactly)
/// </summary> /// </summary>
public string TieBreaker { get { return tieBreaker ?? DefaultTieBreaker; } set { tieBreaker = value; } } public string TieBreaker { get { return tieBreaker ?? DefaultTieBreaker; } set { tieBreaker = value; } }
/// <summary> /// <summary>
/// The size of the output buffer to use /// The size of the output buffer to use
/// </summary> /// </summary>
public int WriteBuffer { get { return writeBuffer.GetValueOrDefault(4096); } set { writeBuffer = value; } } public int WriteBuffer { get { return writeBuffer.GetValueOrDefault(4096); } set { writeBuffer = value; } }
internal LocalCertificateSelectionCallback CertificateSelectionCallback { get { return CertificateSelection; } private set { CertificateSelection = value; } } internal LocalCertificateSelectionCallback CertificateSelectionCallback { get { return CertificateSelection; } private set { CertificateSelection = value; } }
// these just rip out the underlying handlers, bypassing the event accessors - needed when creating the SSL stream // these just rip out the underlying handlers, bypassing the event accessors - needed when creating the SSL stream
internal RemoteCertificateValidationCallback CertificateValidationCallback { get { return CertificateValidation; } private set { CertificateValidation = value; } } internal RemoteCertificateValidationCallback CertificateValidationCallback { get { return CertificateValidation; } private set { CertificateValidation = value; } }
/// <summary> /// <summary>
/// Parse the configuration from a comma-delimited configuration string /// Parse the configuration from a comma-delimited configuration string
/// </summary> /// </summary>
public static ConfigurationOptions Parse(string configuration) public static ConfigurationOptions Parse(string configuration)
{ {
var options = new ConfigurationOptions(); var options = new ConfigurationOptions();
options.DoParse(configuration); options.DoParse(configuration);
return options; return options;
} }
/// <summary> /// <summary>
/// Create a copy of the configuration /// Create a copy of the configuration
/// </summary> /// </summary>
public ConfigurationOptions Clone() public ConfigurationOptions Clone()
{ {
var options = new ConfigurationOptions var options = new ConfigurationOptions
{ {
clientName = clientName, clientName = clientName,
serviceName = serviceName, serviceName = serviceName,
keepAlive = keepAlive, keepAlive = keepAlive,
syncTimeout = syncTimeout, syncTimeout = syncTimeout,
allowAdmin = allowAdmin, allowAdmin = allowAdmin,
defaultVersion = defaultVersion, defaultVersion = defaultVersion,
connectTimeout = connectTimeout, connectTimeout = connectTimeout,
password = password, password = password,
tieBreaker = tieBreaker, tieBreaker = tieBreaker,
writeBuffer = writeBuffer, writeBuffer = writeBuffer,
sslHost = sslHost, useSsl = useSsl,
configChannel = configChannel, sslHost = sslHost,
abortOnConnectFail = abortOnConnectFail, configChannel = configChannel,
resolveDns = resolveDns, abortOnConnectFail = abortOnConnectFail,
proxy = proxy, resolveDns = resolveDns,
commandMap = commandMap, proxy = proxy,
CertificateValidationCallback = CertificateValidationCallback, commandMap = commandMap,
CertificateSelectionCallback = CertificateSelectionCallback, CertificateValidationCallback = CertificateValidationCallback,
ChannelPrefix = ChannelPrefix.Clone(), CertificateSelectionCallback = CertificateSelectionCallback,
SocketManager = SocketManager, ChannelPrefix = ChannelPrefix.Clone(),
}; SocketManager = SocketManager,
foreach (var item in endpoints) };
options.endpoints.Add(item); foreach (var item in endpoints)
return options; options.endpoints.Add(item);
return options;
}
}
/// <summary>
/// Returns the effective configuration string for this configuration /// <summary>
/// </summary> /// Returns the effective configuration string for this configuration
public override string ToString() /// </summary>
{ public override string ToString()
var sb = new StringBuilder(); {
foreach (var endpoint in endpoints) var sb = new StringBuilder();
{ foreach (var endpoint in endpoints)
Append(sb, Format.ToString(endpoint)); {
} Append(sb, Format.ToString(endpoint));
Append(sb, ClientNamePrefix, clientName); }
Append(sb, ServiceNamePrefix, serviceName); Append(sb, ClientNamePrefix, clientName);
Append(sb, KeepAlivePrefix, keepAlive); Append(sb, ServiceNamePrefix, serviceName);
Append(sb, SyncTimeoutPrefix, syncTimeout); Append(sb, KeepAlivePrefix, keepAlive);
Append(sb, AllowAdminPrefix, allowAdmin); Append(sb, SyncTimeoutPrefix, syncTimeout);
Append(sb, VersionPrefix, defaultVersion); Append(sb, AllowAdminPrefix, allowAdmin);
Append(sb, ConnectTimeoutPrefix, connectTimeout); Append(sb, VersionPrefix, defaultVersion);
Append(sb, PasswordPrefix, password); Append(sb, ConnectTimeoutPrefix, connectTimeout);
Append(sb, TieBreakerPrefix, tieBreaker); Append(sb, PasswordPrefix, password);
Append(sb, WriteBufferPrefix, writeBuffer); Append(sb, TieBreakerPrefix, tieBreaker);
Append(sb, SslHostPrefix, sslHost); Append(sb, WriteBufferPrefix, writeBuffer);
Append(sb, ConfigChannelPrefix, configChannel); Append(sb, UseSslPrefix, useSsl);
Append(sb, AbortOnConnectFailPrefix, abortOnConnectFail); Append(sb, SslHostPrefix, sslHost);
Append(sb, ResolveDnsPrefix, resolveDns); Append(sb, ConfigChannelPrefix, configChannel);
Append(sb, ChannelPrefixPrefix, (string)ChannelPrefix); Append(sb, AbortOnConnectFailPrefix, abortOnConnectFail);
Append(sb, ProxyPrefix, proxy); Append(sb, ResolveDnsPrefix, resolveDns);
if(commandMap != null) commandMap.AppendDeltas(sb); Append(sb, ChannelPrefixPrefix, (string)ChannelPrefix);
return sb.ToString(); Append(sb, ProxyPrefix, proxy);
} if(commandMap != null) commandMap.AppendDeltas(sb);
return sb.ToString();
internal bool HasDnsEndPoints() }
{
foreach (var endpoint in endpoints) if (endpoint is DnsEndPoint) return true; internal bool HasDnsEndPoints()
return false; {
} foreach (var endpoint in endpoints) if (endpoint is DnsEndPoint) return true;
return false;
internal async Task ResolveEndPointsAsync(ConnectionMultiplexer multiplexer, TextWriter log) }
{
Dictionary<string, IPAddress> cache = new Dictionary<string, IPAddress>(StringComparer.InvariantCultureIgnoreCase); internal async Task ResolveEndPointsAsync(ConnectionMultiplexer multiplexer, TextWriter log)
for (int i = 0; i < endpoints.Count; i++) {
{ Dictionary<string, IPAddress> cache = new Dictionary<string, IPAddress>(StringComparer.InvariantCultureIgnoreCase);
var dns = endpoints[i] as DnsEndPoint; for (int i = 0; i < endpoints.Count; i++)
if (dns != null) {
{ var dns = endpoints[i] as DnsEndPoint;
try if (dns != null)
{ {
IPAddress ip; try
if (dns.Host == ".") {
{ IPAddress ip;
endpoints[i] = new IPEndPoint(IPAddress.Loopback, dns.Port); if (dns.Host == ".")
} {
else if (cache.TryGetValue(dns.Host, out ip)) endpoints[i] = new IPEndPoint(IPAddress.Loopback, dns.Port);
{ // use cache }
endpoints[i] = new IPEndPoint(ip, dns.Port); else if (cache.TryGetValue(dns.Host, out ip))
} { // use cache
else endpoints[i] = new IPEndPoint(ip, dns.Port);
{ }
multiplexer.LogLocked(log, "Using DNS to resolve '{0}'...", dns.Host); else
var ips = await Dns.GetHostAddressesAsync(dns.Host).ObserveErrors().ForAwait(); {
if (ips.Length == 1) multiplexer.LogLocked(log, "Using DNS to resolve '{0}'...", dns.Host);
{ var ips = await Dns.GetHostAddressesAsync(dns.Host).ObserveErrors().ForAwait();
ip = ips[0]; if (ips.Length == 1)
multiplexer.LogLocked(log, "'{0}' => {1}", dns.Host, ip); {
cache[dns.Host] = ip; ip = ips[0];
endpoints[i] = new IPEndPoint(ip, dns.Port); multiplexer.LogLocked(log, "'{0}' => {1}", dns.Host, ip);
} cache[dns.Host] = ip;
} endpoints[i] = new IPEndPoint(ip, dns.Port);
} }
catch (Exception ex) }
{ }
multiplexer.OnInternalError(ex); catch (Exception ex)
multiplexer.LogLocked(log, ex.Message); {
} multiplexer.OnInternalError(ex);
} multiplexer.LogLocked(log, ex.Message);
} }
} }
}
static void Append(StringBuilder sb, object value) }
{
if (value == null) return; static void Append(StringBuilder sb, object value)
string s = Format.ToString(value); {
if (!string.IsNullOrWhiteSpace(s)) if (value == null) return;
{ string s = Format.ToString(value);
if (sb.Length != 0) sb.Append(','); if (!string.IsNullOrWhiteSpace(s))
sb.Append(s); {
} if (sb.Length != 0) sb.Append(',');
} sb.Append(s);
}
static void Append(StringBuilder sb, string prefix, object value) }
{
if (value == null) return; static void Append(StringBuilder sb, string prefix, object value)
string s = value.ToString(); {
if (!string.IsNullOrWhiteSpace(s)) if (value == null) return;
{ string s = value.ToString();
if (sb.Length != 0) sb.Append(','); if (!string.IsNullOrWhiteSpace(s))
sb.Append(prefix).Append(s); {
} if (sb.Length != 0) sb.Append(',');
} sb.Append(prefix).Append(s);
}
static bool IsOption(string option, string prefix) }
{
return option.StartsWith(prefix, StringComparison.InvariantCultureIgnoreCase); static bool IsOption(string option, string prefix)
} {
void Clear() return option.StartsWith(prefix, StringComparison.InvariantCultureIgnoreCase);
{ }
clientName = serviceName = password = tieBreaker = sslHost = configChannel = null; void Clear()
keepAlive = syncTimeout = connectTimeout = writeBuffer = null; {
allowAdmin = abortOnConnectFail = resolveDns = null; clientName = serviceName = password = tieBreaker = sslHost = configChannel = null;
defaultVersion = null; keepAlive = syncTimeout = connectTimeout = writeBuffer = null;
endpoints.Clear(); allowAdmin = abortOnConnectFail = resolveDns = useSsl = null;
commandMap = null; defaultVersion = null;
endpoints.Clear();
CertificateSelection = null; commandMap = null;
CertificateValidation = null;
ChannelPrefix = default(RedisChannel); CertificateSelection = null;
SocketManager = null; CertificateValidation = null;
} ChannelPrefix = default(RedisChannel);
SocketManager = null;
object ICloneable.Clone() { return Clone(); } }
private void DoParse(string configuration) object ICloneable.Clone() { return Clone(); }
{
Clear(); private void DoParse(string configuration)
if (!string.IsNullOrWhiteSpace(configuration)) {
{ Clear();
// break it down by commas if (!string.IsNullOrWhiteSpace(configuration))
var arr = configuration.Split(StringSplits.Comma); {
Dictionary<string, string> map = null; // break it down by commas
foreach (var paddedOption in arr) var arr = configuration.Split(StringSplits.Comma);
{ Dictionary<string, string> map = null;
var option = paddedOption.Trim(); foreach (var paddedOption in arr)
{
if (string.IsNullOrWhiteSpace(option)) continue; var option = paddedOption.Trim();
// check for special tokens if (string.IsNullOrWhiteSpace(option)) continue;
int idx = option.IndexOf('=');
if (idx > 0) // check for special tokens
{ int idx = option.IndexOf('=');
var value = option.Substring(idx + 1).Trim(); if (idx > 0)
if (IsOption(option, SyncTimeoutPrefix)) {
{ var value = option.Substring(idx + 1).Trim();
int tmp; if (IsOption(option, SyncTimeoutPrefix))
if (Format.TryParseInt32(value.Trim(), out tmp) && tmp > 0) SyncTimeout = tmp; {
} int tmp;
else if (IsOption(option, AllowAdminPrefix)) if (Format.TryParseInt32(value.Trim(), out tmp) && tmp > 0) SyncTimeout = tmp;
{ }
bool tmp; else if (IsOption(option, AllowAdminPrefix))
if (Format.TryParseBoolean(value.Trim(), out tmp)) AllowAdmin = tmp; {
} bool tmp;
else if (IsOption(option, AbortOnConnectFailPrefix)) if (Format.TryParseBoolean(value.Trim(), out tmp)) AllowAdmin = tmp;
{ }
bool tmp; else if (IsOption(option, AbortOnConnectFailPrefix))
if (Format.TryParseBoolean(value.Trim(), out tmp)) AbortOnConnectFail = tmp; {
} bool tmp;
else if (IsOption(option, ResolveDnsPrefix)) if (Format.TryParseBoolean(value.Trim(), out tmp)) AbortOnConnectFail = tmp;
{ }
bool tmp; else if (IsOption(option, ResolveDnsPrefix))
if (Format.TryParseBoolean(value.Trim(), out tmp)) ResolveDns = tmp; {
} bool tmp;
else if (IsOption(option, ServiceNamePrefix)) if (Format.TryParseBoolean(value.Trim(), out tmp)) ResolveDns = tmp;
{ }
ServiceName = value.Trim(); else if (IsOption(option, ServiceNamePrefix))
} {
else if (IsOption(option, ClientNamePrefix)) ServiceName = value.Trim();
{ }
ClientName = value.Trim(); else if (IsOption(option, ClientNamePrefix))
} {
else if (IsOption(option, ChannelPrefixPrefix)) ClientName = value.Trim();
{ }
ChannelPrefix = value.Trim(); else if (IsOption(option, ChannelPrefixPrefix))
} {
else if (IsOption(option, ConfigChannelPrefix)) ChannelPrefix = value.Trim();
{ }
ConfigurationChannel = value.Trim(); else if (IsOption(option, ConfigChannelPrefix))
} {
else if (IsOption(option, KeepAlivePrefix)) ConfigurationChannel = value.Trim();
{ }
int tmp; else if (IsOption(option, KeepAlivePrefix))
if (Format.TryParseInt32(value.Trim(), out tmp)) KeepAlive = tmp; {
} int tmp;
else if (IsOption(option, ConnectTimeoutPrefix)) if (Format.TryParseInt32(value.Trim(), out tmp)) KeepAlive = tmp;
{ }
int tmp; else if (IsOption(option, ConnectTimeoutPrefix))
if (Format.TryParseInt32(value.Trim(), out tmp)) ConnectTimeout = tmp; {
} int tmp;
else if (IsOption(option, VersionPrefix)) if (Format.TryParseInt32(value.Trim(), out tmp)) ConnectTimeout = tmp;
{ }
Version tmp; else if (IsOption(option, VersionPrefix))
if (Version.TryParse(value.Trim(), out tmp)) DefaultVersion = tmp; {
} Version tmp;
else if (IsOption(option, PasswordPrefix)) if (Version.TryParse(value.Trim(), out tmp)) DefaultVersion = tmp;
{ }
Password = value.Trim(); else if (IsOption(option, PasswordPrefix))
} {
else if (IsOption(option, TieBreakerPrefix)) Password = value.Trim();
{ }
TieBreaker = value.Trim(); else if (IsOption(option, TieBreakerPrefix))
} {
else if (IsOption(option, SslHostPrefix)) TieBreaker = value.Trim();
{ }
SslHost = value.Trim(); else if (IsOption(option, UseSslPrefix))
} {
else if (IsOption(option, WriteBufferPrefix)) bool tmp;
{ if (Format.TryParseBoolean(value.Trim(), out tmp)) UseSsl = tmp;
int tmp; }
if (Format.TryParseInt32(value.Trim(), out tmp)) WriteBuffer = tmp; else if (IsOption(option, SslHostPrefix))
{
SslHost = value.Trim();
}
else if (IsOption(option, WriteBufferPrefix))
{
int tmp;
if (Format.TryParseInt32(value.Trim(), out tmp)) WriteBuffer = tmp;
} else if(IsOption(option, ProxyPrefix)) } else if(IsOption(option, ProxyPrefix))
{ {
Proxy tmp; Proxy tmp;
if (Enum.TryParse(option, true, out tmp)) Proxy = tmp; if (Enum.TryParse(option, true, out tmp)) Proxy = tmp;
} }
else if(option[0]=='$') else if(option[0]=='$')
{ {
RedisCommand cmd; RedisCommand cmd;
option = option.Substring(1, idx-1); option = option.Substring(1, idx-1);
if (Enum.TryParse(option, true, out cmd)) if (Enum.TryParse(option, true, out cmd))
{ {
if (map == null) map = new Dictionary<string, string>(StringComparer.InvariantCultureIgnoreCase); if (map == null) map = new Dictionary<string, string>(StringComparer.InvariantCultureIgnoreCase);
map[option] = value; map[option] = value;
} }
} }
else else
{ {
ConnectionMultiplexer.TraceWithoutContext("Unknown configuration option:" + option); ConnectionMultiplexer.TraceWithoutContext("Unknown configuration option:" + option);
} }
} }
else else
{ {
var ep = Format.TryParseEndPoint(option); var ep = Format.TryParseEndPoint(option);
if (ep != null && !endpoints.Contains(ep)) endpoints.Add(ep); if (ep != null && !endpoints.Contains(ep)) endpoints.Add(ep);
} }
} }
if (map != null && map.Count != 0) if (map != null && map.Count != 0)
{ {
this.CommandMap = CommandMap.Create(map); this.CommandMap = CommandMap.Create(map);
} }
} }
} }
} }
} }
...@@ -79,6 +79,20 @@ internal static string ToString(EndPoint endpoint) ...@@ -79,6 +79,20 @@ internal static string ToString(EndPoint endpoint)
return dns.Host + ":" + Format.ToString(dns.Port); return dns.Host + ":" + Format.ToString(dns.Port);
} }
} }
internal static string ToStringHostOnly(EndPoint endpoint)
{
var dns = endpoint as DnsEndPoint;
if (dns != null)
{
return dns.Host;
}
var ip = endpoint as IPEndPoint;
if(ip != null)
{
return ip.Address.ToString();
}
return "";
}
internal static bool TryGetHostPort(EndPoint endpoint, out string host, out int port) internal static bool TryGetHostPort(EndPoint endpoint, out string host, out int port)
{ {
......
...@@ -522,14 +522,17 @@ SocketMode ISocketCallback.Connected(Stream stream) ...@@ -522,14 +522,17 @@ SocketMode ISocketCallback.Connected(Stream stream)
// [network]<==[ssl]<==[logging]<==[buffered] // [network]<==[ssl]<==[logging]<==[buffered]
var config = multiplexer.RawConfig; var config = multiplexer.RawConfig;
if (!string.IsNullOrWhiteSpace(config.SslHost)) if(config.UseSsl)
{ {
var host = config.SslHost;
if (string.IsNullOrWhiteSpace(host)) host = Format.ToStringHostOnly(bridge.ServerEndPoint.EndPoint);
var ssl = new SslStream(stream, false, config.CertificateValidationCallback, config.CertificateSelectionCallback var ssl = new SslStream(stream, false, config.CertificateValidationCallback, config.CertificateSelectionCallback
#if !MONO #if !MONO
, EncryptionPolicy.RequireEncryption , EncryptionPolicy.RequireEncryption
#endif #endif
); );
ssl.AuthenticateAsClient(config.SslHost); ssl.AuthenticateAsClient(host);
if (!ssl.IsEncrypted) if (!ssl.IsEncrypted)
{ {
RecordConnectionFailed(ConnectionFailureType.AuthenticationFailure); RecordConnectionFailed(ConnectionFailureType.AuthenticationFailure);
......
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Diagnostics; using System.Diagnostics;
using System.IO; using System.IO;
using System.Net; using System.Net;
using System.Net.Sockets; using System.Net.Sockets;
using System.Runtime.InteropServices; using System.Runtime.InteropServices;
using System.Threading; using System.Threading;
namespace StackExchange.Redis namespace StackExchange.Redis
{ {
internal enum SocketMode internal enum SocketMode
{ {
Abort, Abort,
Poll, Poll,
Async Async
} }
/// <summary> /// <summary>
/// Allows callbacks from SocketManager as work is discovered /// Allows callbacks from SocketManager as work is discovered
/// </summary> /// </summary>
internal interface ISocketCallback internal interface ISocketCallback
{ {
/// <summary> /// <summary>
/// Indicates that a socket has connected /// Indicates that a socket has connected
/// </summary> /// </summary>
SocketMode Connected(Stream stream); SocketMode Connected(Stream stream);
/// <summary> /// <summary>
/// Indicates that the socket has signalled an error condition /// Indicates that the socket has signalled an error condition
/// </summary> /// </summary>
void Error(); void Error();
void OnHeartbeat(); void OnHeartbeat();
/// <summary> /// <summary>
/// Indicates that data is available on the socket, and that the consumer should read synchronously from the socket while there is data /// Indicates that data is available on the socket, and that the consumer should read synchronously from the socket while there is data
/// </summary> /// </summary>
void Read(); void Read();
/// <summary> /// <summary>
/// Indicates that we cannot know whether data is available, and that the consume should commence reading asynchronously /// Indicates that we cannot know whether data is available, and that the consume should commence reading asynchronously
/// </summary> /// </summary>
void StartReading(); void StartReading();
} }
internal struct SocketToken internal struct SocketToken
{ {
internal readonly Socket Socket; internal readonly Socket Socket;
public SocketToken(Socket socket) public SocketToken(Socket socket)
{ {
this.Socket = socket; this.Socket = socket;
} }
public int Available { get { return Socket == null ? 0 : Socket.Available; } } public int Available { get { return Socket == null ? 0 : Socket.Available; } }
public bool HasValue { get { return Socket != null; } } public bool HasValue { get { return Socket != null; } }
} }
/// <summary> /// <summary>
/// A SocketManager monitors multiple sockets for availability of data; this is done using /// A SocketManager monitors multiple sockets for availability of data; this is done using
/// the Socket.Select API and a dedicated reader-thread, which allows for fast responses /// the Socket.Select API and a dedicated reader-thread, which allows for fast responses
/// even when the system is under ambient load. /// even when the system is under ambient load.
/// </summary> /// </summary>
public sealed partial class SocketManager : IDisposable public sealed partial class SocketManager : IDisposable
{ {
private static readonly ParameterizedThreadStart writeAllQueues = context => private static readonly ParameterizedThreadStart writeAllQueues = context =>
{ {
try { ((SocketManager)context).WriteAllQueues(); } catch { } try { ((SocketManager)context).WriteAllQueues(); } catch { }
}; };
private static readonly WaitCallback writeOneQueue = context => private static readonly WaitCallback writeOneQueue = context =>
{ {
try { ((SocketManager)context).WriteOneQueue(); } catch { } try { ((SocketManager)context).WriteOneQueue(); } catch { }
}; };
private readonly string name; private readonly string name;
...@@ -78,211 +78,211 @@ public sealed partial class SocketManager : IDisposable ...@@ -78,211 +78,211 @@ public sealed partial class SocketManager : IDisposable
bool isDisposed; bool isDisposed;
/// <summary> /// <summary>
/// Creates a new (optionally named) SocketManager instance /// Creates a new (optionally named) SocketManager instance
/// </summary> /// </summary>
public SocketManager(string name = null) public SocketManager(string name = null)
{ {
if (string.IsNullOrWhiteSpace(name)) name = GetType().Name; if (string.IsNullOrWhiteSpace(name)) name = GetType().Name;
this.name = name; this.name = name;
// we need a dedicated writer, because when under heavy ambient load // we need a dedicated writer, because when under heavy ambient load
// (a busy asp.net site, for example), workers are not reliable enough // (a busy asp.net site, for example), workers are not reliable enough
Thread dedicatedWriter = new Thread(writeAllQueues, 32 * 1024); // don't need a huge stack; Thread dedicatedWriter = new Thread(writeAllQueues, 32 * 1024); // don't need a huge stack;
dedicatedWriter.Priority = ThreadPriority.AboveNormal; // time critical dedicatedWriter.Priority = ThreadPriority.AboveNormal; // time critical
dedicatedWriter.Name = name + ":Write"; dedicatedWriter.Name = name + ":Write";
dedicatedWriter.IsBackground = true; // should not keep process alive dedicatedWriter.IsBackground = true; // should not keep process alive
dedicatedWriter.Start(this); // will self-exit when disposed dedicatedWriter.Start(this); // will self-exit when disposed
} }
private enum CallbackOperation private enum CallbackOperation
{ {
Read, Read,
Error Error
} }
/// <summary> /// <summary>
/// Gets the name of this SocketManager instance /// Gets the name of this SocketManager instance
/// </summary> /// </summary>
public string Name { get { return name; } } public string Name { get { return name; } }
/// <summary> /// <summary>
/// Releases all resources associated with this instance /// Releases all resources associated with this instance
/// </summary> /// </summary>
public void Dispose() public void Dispose()
{ {
lock (writeQueue) lock (writeQueue)
{ {
// make sure writer threads know to exit // make sure writer threads know to exit
isDisposed = true; isDisposed = true;
Monitor.PulseAll(writeQueue); Monitor.PulseAll(writeQueue);
} }
OnDispose(); OnDispose();
} }
internal SocketToken BeginConnect(EndPoint endpoint, ISocketCallback callback) internal SocketToken BeginConnect(EndPoint endpoint, ISocketCallback callback)
{ {
var socket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); var socket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
socket.NoDelay = true; socket.NoDelay = true;
socket.BeginConnect(endpoint, EndConnect, Tuple.Create(socket, callback)); socket.BeginConnect(endpoint, EndConnect, Tuple.Create(socket, callback));
return new SocketToken(socket); return new SocketToken(socket);
} }
internal void RequestWrite(PhysicalBridge bridge, bool forced) internal void RequestWrite(PhysicalBridge bridge, bool forced)
{ {
if (Interlocked.CompareExchange(ref bridge.inWriteQueue, 1, 0) == 0 || forced) if (Interlocked.CompareExchange(ref bridge.inWriteQueue, 1, 0) == 0 || forced)
{ {
lock (writeQueue) lock (writeQueue)
{ {
writeQueue.Enqueue(bridge); writeQueue.Enqueue(bridge);
if (writeQueue.Count == 1) if (writeQueue.Count == 1)
{ {
Monitor.PulseAll(writeQueue); Monitor.PulseAll(writeQueue);
} }
else if (writeQueue.Count >= 2) else if (writeQueue.Count >= 2)
{ // struggling are we? let's have some help dealing with the backlog { // struggling are we? let's have some help dealing with the backlog
ThreadPool.QueueUserWorkItem(writeOneQueue, this); ThreadPool.QueueUserWorkItem(writeOneQueue, this);
} }
} }
} }
} }
internal void Shutdown(SocketToken token) internal void Shutdown(SocketToken token)
{ {
Shutdown(token.Socket); Shutdown(token.Socket);
} }
private void EndConnect(IAsyncResult ar) private void EndConnect(IAsyncResult ar)
{ {
Tuple<Socket, ISocketCallback> tuple = null; Tuple<Socket, ISocketCallback> tuple = null;
try try
{ {
tuple = (Tuple<Socket, ISocketCallback>)ar.AsyncState; tuple = (Tuple<Socket, ISocketCallback>)ar.AsyncState;
var socket = tuple.Item1; var socket = tuple.Item1;
var callback = tuple.Item2; var callback = tuple.Item2;
socket.EndConnect(ar); socket.EndConnect(ar);
var netStream = new NetworkStream(socket, false); var netStream = new NetworkStream(socket, false);
var socketMode = callback == null ? SocketMode.Abort : callback.Connected(netStream); var socketMode = callback == null ? SocketMode.Abort : callback.Connected(netStream);
switch (socketMode) switch (socketMode)
{ {
case SocketMode.Poll: case SocketMode.Poll:
OnAddRead(socket, callback); OnAddRead(socket, callback);
break; break;
case SocketMode.Async: case SocketMode.Async:
try try
{ callback.StartReading(); } { callback.StartReading(); }
catch catch
{ Shutdown(socket); } { Shutdown(socket); }
break; break;
default: default:
Shutdown(socket); Shutdown(socket);
break; break;
} }
} }
catch catch
{ {
if (tuple != null) if (tuple != null)
{ {
try try
{ tuple.Item2.Error(); } { tuple.Item2.Error(); }
catch (Exception ex) catch (Exception ex)
{ {
Trace.WriteLine(ex); Trace.WriteLine(ex);
} }
} }
} }
} }
partial void OnDispose(); partial void OnDispose();
partial void OnShutdown(Socket socket); partial void OnShutdown(Socket socket);
[System.Diagnostics.CodeAnalysis.SuppressMessage("Microsoft.Usage", "CA2202:Do not dispose objects multiple times")] [System.Diagnostics.CodeAnalysis.SuppressMessage("Microsoft.Usage", "CA2202:Do not dispose objects multiple times")]
private void Shutdown(Socket socket) private void Shutdown(Socket socket)
{ {
if (socket != null) if (socket != null)
{ {
OnShutdown(socket); OnShutdown(socket);
try { socket.Shutdown(SocketShutdown.Both); } catch { } try { socket.Shutdown(SocketShutdown.Both); } catch { }
try { socket.Close(); } catch { } try { socket.Close(); } catch { }
try { socket.Dispose(); } catch { } try { socket.Dispose(); } catch { }
} }
} }
private void WriteAllQueues() private void WriteAllQueues()
{ {
while (true) while (true)
{ {
PhysicalBridge bridge; PhysicalBridge bridge;
lock (writeQueue) lock (writeQueue)
{ {
if (writeQueue.Count == 0) if (writeQueue.Count == 0)
{ {
if (isDisposed) break; // <========= exit point if (isDisposed) break; // <========= exit point
Monitor.Wait(writeQueue); Monitor.Wait(writeQueue);
if (isDisposed) break; // (woken by Dispose) if (isDisposed) break; // (woken by Dispose)
if (writeQueue.Count == 0) continue; // still nothing... if (writeQueue.Count == 0) continue; // still nothing...
} }
bridge = writeQueue.Dequeue(); bridge = writeQueue.Dequeue();
} }
switch (bridge.WriteQueue(200)) switch (bridge.WriteQueue(200))
{ {
case WriteResult.MoreWork: case WriteResult.MoreWork:
case WriteResult.QueueEmptyAfterWrite: case WriteResult.QueueEmptyAfterWrite:
// back of the line! // back of the line!
lock (writeQueue) lock (writeQueue)
{ {
writeQueue.Enqueue(bridge); writeQueue.Enqueue(bridge);
} }
break; break;
case WriteResult.CompetingWriter: case WriteResult.CompetingWriter:
break; break;
case WriteResult.NoConnection: case WriteResult.NoConnection:
Interlocked.Exchange(ref bridge.inWriteQueue, 0); Interlocked.Exchange(ref bridge.inWriteQueue, 0);
break; break;
case WriteResult.NothingToDo: case WriteResult.NothingToDo:
if (!bridge.ConfirmRemoveFromWriteQueue()) if (!bridge.ConfirmRemoveFromWriteQueue())
{ // more snuck in; back of the line! { // more snuck in; back of the line!
lock (writeQueue) lock (writeQueue)
{ {
writeQueue.Enqueue(bridge); writeQueue.Enqueue(bridge);
} }
} }
break; break;
} }
} }
} }
private void WriteOneQueue() private void WriteOneQueue()
{ {
PhysicalBridge bridge; PhysicalBridge bridge;
lock (writeQueue) lock (writeQueue)
{ {
bridge = writeQueue.Count == 0 ? null : writeQueue.Dequeue(); bridge = writeQueue.Count == 0 ? null : writeQueue.Dequeue();
} }
if (bridge == null) return; if (bridge == null) return;
bool keepGoing; bool keepGoing;
do do
{ {
switch (bridge.WriteQueue(-1)) switch (bridge.WriteQueue(-1))
{ {
case WriteResult.MoreWork: case WriteResult.MoreWork:
case WriteResult.QueueEmptyAfterWrite: case WriteResult.QueueEmptyAfterWrite:
keepGoing = true; keepGoing = true;
break; break;
case WriteResult.NothingToDo: case WriteResult.NothingToDo:
keepGoing = !bridge.ConfirmRemoveFromWriteQueue(); keepGoing = !bridge.ConfirmRemoveFromWriteQueue();
break; break;
case WriteResult.CompetingWriter: case WriteResult.CompetingWriter:
keepGoing = false; keepGoing = false;
break; break;
case WriteResult.NoConnection: case WriteResult.NoConnection:
Interlocked.Exchange(ref bridge.inWriteQueue, 0); Interlocked.Exchange(ref bridge.inWriteQueue, 0);
keepGoing = false; keepGoing = false;
break; break;
default: default:
keepGoing = false; keepGoing = false;
break; break;
} }
} while (keepGoing); } while (keepGoing);
} }
} }
} }
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment