Commit d0880697 authored by Marc Gravell's avatar Marc Gravell

Simpler SSL connect

parent bc9e4fe2
......@@ -4,6 +4,7 @@
using System.IO;
using System.Net;
using NUnit.Framework;
using System.Linq;
namespace StackExchange.Redis.Tests
{
......@@ -11,10 +12,17 @@ namespace StackExchange.Redis.Tests
public class SSL : TestBase
{
[Test]
[TestCase(6379, null)]
[TestCase(6380, "as if we care")]
public void ConnectToSSLServer(int port, string sslHost)
[TestCase(6379, false, false)]
[TestCase(6380, true, false)]
[TestCase(6380, true, true)]
public void ConnectToSSLServer(int port, bool useSsl, bool specifyHost)
{
string host = null;
const string path = @"D:\RedisSslHost.txt"; // because I choose not to advertise my server here!
if (File.Exists(path)) host = File.ReadLines(path).First();
if (string.IsNullOrWhiteSpace(host)) Assert.Inconclusive("no ssl host specified at: " + path);
var config = new ConfigurationOptions
{
CommandMap = CommandMap.Create( // looks like "config" is disabled
......@@ -24,18 +32,36 @@ public void ConnectToSSLServer(int port, string sslHost)
{ "cluster", null }
}
),
SslHost = sslHost,
EndPoints = { { "sslredis", port} },
EndPoints = { { host, port} },
AllowAdmin = true,
SyncTimeout = Debugger.IsAttached ? int.MaxValue : 5000
};
if(useSsl)
{
config.UseSsl = useSsl;
if (specifyHost)
{
config.SslHost = host;
}
config.CertificateValidation += (sender, cert, chain, errors) =>
{
Console.WriteLine("errors: " + errors);
Console.WriteLine("cert issued to: " + cert.Subject);
return true; // fingers in ears, pretend we don't know this is wrong
};
using (var muxer = ConnectionMultiplexer.Connect(config, Console.Out))
}
var configString = config.ToString();
Console.WriteLine("config: " + configString);
var clone = ConfigurationOptions.Parse(configString);
Assert.AreEqual(configString, clone.ToString(), "config string");
using(var log = new StringWriter())
using (var muxer = ConnectionMultiplexer.Connect(config, log))
{
Console.WriteLine("Connect log:");
Console.WriteLine(log);
Console.WriteLine("====");
muxer.ConnectionFailed += OnConnectionFailed;
muxer.InternalError += OnInternalError;
var db = muxer.GetDatabase();
......@@ -66,13 +92,14 @@ public void ConnectToSSLServer(int port, string sslHost)
// perf: sync/multi-threaded
TestConcurrent(db, key, 30, 10);
TestConcurrent(db, key, 30, 20);
TestConcurrent(db, key, 30, 30);
TestConcurrent(db, key, 30, 40);
TestConcurrent(db, key, 30, 50);
//TestConcurrent(db, key, 30, 20);
//TestConcurrent(db, key, 30, 30);
//TestConcurrent(db, key, 30, 40);
//TestConcurrent(db, key, 30, 50);
}
}
private static void TestConcurrent(IDatabase db, RedisKey key, int SyncLoop, int Threads)
{
long value;
......
......@@ -34,13 +34,13 @@ public sealed class ConfigurationOptions : ICloneable
private const string AllowAdminPrefix = "allowAdmin=", SyncTimeoutPrefix = "syncTimeout=",
ServiceNamePrefix = "serviceName=", ClientNamePrefix = "name=", KeepAlivePrefix = "keepAlive=",
VersionPrefix = "version=", ConnectTimeoutPrefix = "connectTimeout=", PasswordPrefix = "password=",
TieBreakerPrefix = "tiebreaker=", WriteBufferPrefix = "writeBuffer=", SslHostPrefix = "sslHost=",
TieBreakerPrefix = "tiebreaker=", WriteBufferPrefix = "writeBuffer=", UseSslPrefix = "ssl=", SslHostPrefix = "sslHost=",
ConfigChannelPrefix = "configChannel=", AbortOnConnectFailPrefix = "abortConnect=", ResolveDnsPrefix = "resolveDns=",
ChannelPrefixPrefix = "channelPrefix=", ProxyPrefix = "proxy=";
private readonly EndPointCollection endpoints = new EndPointCollection();
private bool? allowAdmin, abortOnConnectFail, resolveDns;
private bool? allowAdmin, abortOnConnectFail, resolveDns, useSsl;
private string clientName, serviceName, password, tieBreaker, sslHost, configChannel;
......@@ -76,6 +76,11 @@ public sealed class ConfigurationOptions : ICloneable
/// </summary>
public bool AllowAdmin { get { return allowAdmin.GetValueOrDefault(); } set { allowAdmin = value; } }
/// <summary>
/// Indicates whether the connection should be encrypted
/// </summary>
public bool UseSsl { get { return useSsl.GetValueOrDefault(); } set { useSsl = value; } }
/// <summary>
/// Automatically encodes and decodes channels
/// </summary>
......@@ -208,6 +213,7 @@ public ConfigurationOptions Clone()
password = password,
tieBreaker = tieBreaker,
writeBuffer = writeBuffer,
useSsl = useSsl,
sslHost = sslHost,
configChannel = configChannel,
abortOnConnectFail = abortOnConnectFail,
......@@ -245,6 +251,7 @@ public override string ToString()
Append(sb, PasswordPrefix, password);
Append(sb, TieBreakerPrefix, tieBreaker);
Append(sb, WriteBufferPrefix, writeBuffer);
Append(sb, UseSslPrefix, useSsl);
Append(sb, SslHostPrefix, sslHost);
Append(sb, ConfigChannelPrefix, configChannel);
Append(sb, AbortOnConnectFailPrefix, abortOnConnectFail);
......@@ -332,7 +339,7 @@ void Clear()
{
clientName = serviceName = password = tieBreaker = sslHost = configChannel = null;
keepAlive = syncTimeout = connectTimeout = writeBuffer = null;
allowAdmin = abortOnConnectFail = resolveDns = null;
allowAdmin = abortOnConnectFail = resolveDns = useSsl = null;
defaultVersion = null;
endpoints.Clear();
commandMap = null;
......@@ -423,6 +430,11 @@ private void DoParse(string configuration)
{
TieBreaker = value.Trim();
}
else if (IsOption(option, UseSslPrefix))
{
bool tmp;
if (Format.TryParseBoolean(value.Trim(), out tmp)) UseSsl = tmp;
}
else if (IsOption(option, SslHostPrefix))
{
SslHost = value.Trim();
......
......@@ -79,6 +79,20 @@ internal static string ToString(EndPoint endpoint)
return dns.Host + ":" + Format.ToString(dns.Port);
}
}
internal static string ToStringHostOnly(EndPoint endpoint)
{
var dns = endpoint as DnsEndPoint;
if (dns != null)
{
return dns.Host;
}
var ip = endpoint as IPEndPoint;
if(ip != null)
{
return ip.Address.ToString();
}
return "";
}
internal static bool TryGetHostPort(EndPoint endpoint, out string host, out int port)
{
......
......@@ -522,14 +522,17 @@ SocketMode ISocketCallback.Connected(Stream stream)
// [network]<==[ssl]<==[logging]<==[buffered]
var config = multiplexer.RawConfig;
if (!string.IsNullOrWhiteSpace(config.SslHost))
if(config.UseSsl)
{
var host = config.SslHost;
if (string.IsNullOrWhiteSpace(host)) host = Format.ToStringHostOnly(bridge.ServerEndPoint.EndPoint);
var ssl = new SslStream(stream, false, config.CertificateValidationCallback, config.CertificateSelectionCallback
#if !MONO
, EncryptionPolicy.RequireEncryption
#endif
);
ssl.AuthenticateAsClient(config.SslHost);
ssl.AuthenticateAsClient(host);
if (!ssl.IsEncrypted)
{
RecordConnectionFailed(ConnectionFailureType.AuthenticationFailure);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment