Commit ab182e4a authored by Pavel Pochobut's avatar Pavel Pochobut

Fixed race condition in PhysicalConnection creation

parent 5ba566b1
using NUnit.Framework;
using System;
using System.Threading;
namespace StackExchange.Redis.Tests
{
[TestFixture]
public class ConnectingFailDetection : TestBase
{
#if DEBUG
[TestCase]
public void FastNoticesFailOnConnectingSync()
{
try
{
using (var muxer = Create(keepAlive: 1, connectTimeout: 10000, allowAdmin: true))
{
var conn = muxer.GetDatabase();
conn.Ping();
var server = muxer.GetServer(muxer.GetEndPoints()[0]);
muxer.AllowConnect = false;
SocketManager.ConnectCompletionType = CompletionType.Sync;
server.SimulateConnectionFailure();
Assert.IsFalse(muxer.IsConnected);
// should reconnect within 1 keepalive interval
muxer.AllowConnect = true;
Console.WriteLine("Waiting for reconnect");
Thread.Sleep(2000);
Assert.IsTrue(muxer.IsConnected);
}
ClearAmbientFailures();
}
finally
{
SocketManager.ConnectCompletionType = CompletionType.Any;
}
}
[TestCase]
public void FastNoticesFailOnConnectingAsync()
{
try
{
using (var muxer = Create(keepAlive: 1, connectTimeout: 10000, allowAdmin: true))
{
var conn = muxer.GetDatabase();
conn.Ping();
var server = muxer.GetServer(muxer.GetEndPoints()[0]);
muxer.AllowConnect = false;
SocketManager.ConnectCompletionType = CompletionType.Async;
server.SimulateConnectionFailure();
Assert.IsFalse(muxer.IsConnected);
// should reconnect within 1 keepalive interval
muxer.AllowConnect = true;
Console.WriteLine("Waiting for reconnect");
Thread.Sleep(2000);
Assert.IsTrue(muxer.IsConnected);
ClearAmbientFailures();
}
}
finally
{
SocketManager.ConnectCompletionType = CompletionType.Any;
}
}
#endif
}
}
......@@ -66,6 +66,7 @@
<ItemGroup>
<Compile Include="AsyncTests.cs" />
<Compile Include="BasicOps.cs" />
<Compile Include="ConnectingFailDetection.cs" />
<Compile Include="HyperLogLog.cs" />
<Compile Include="WrapperBaseTests.cs" />
<Compile Include="TransactionWrapperTests.cs" />
......
......@@ -76,6 +76,10 @@ protected void OnInternalError(object sender, InternalErrorEventArgs e)
volatile int expectedFailCount;
[SetUp]
public void Setup()
{
ClearAmbientFailures();
}
public void ClearAmbientFailures()
{
Collect();
Interlocked.Exchange(ref failCount, 0);
......
......@@ -191,6 +191,16 @@ partial class SocketManager
{
ignore = callback.IgnoreConnect;
}
/// <summary>
/// Completion type for BeginConnect call
/// </summary>
public static CompletionType ConnectCompletionType { get; set; }
partial void ShouldForceConnectCompletionType(ref CompletionType completionType)
{
completionType = SocketManager.ConnectCompletionType;
}
}
partial interface ISocketCallback
{
......@@ -253,6 +263,69 @@ bool ISocketCallback.IgnoreConnect
}
#endif
/// <summary>
/// Completion type for CompletionTypeHelper
/// </summary>
public enum CompletionType
{
/// <summary>
/// Retain original completion type (either sync or async)
/// </summary>
Any = 0,
/// <summary>
/// Force sync completion
/// </summary>
Sync = 1,
/// <summary>
/// Force async completion
/// </summary>
Async = 2
}
internal class CompletionTypeHelper
{
public static void RunWithCompletionType(Func<AsyncCallback, IAsyncResult> beginAsync, AsyncCallback callback, CompletionType completionType)
{
AsyncCallback proxyCallback;
if (completionType == CompletionType.Any)
{
proxyCallback = (ar) =>
{
if (!ar.CompletedSynchronously)
{
callback(ar);
}
};
}
else
{
proxyCallback = (ar) => { };
}
var result = beginAsync(proxyCallback);
if (completionType == CompletionType.Any && !result.CompletedSynchronously)
{
return;
}
result.AsyncWaitHandle.WaitOne();
switch (completionType)
{
case CompletionType.Async:
ThreadPool.QueueUserWorkItem((s) => { callback(result); });
break;
case CompletionType.Any:
case CompletionType.Sync:
callback(result);
break;
}
return;
}
}
#if VERBOSE
partial class ConnectionMultiplexer
......
......@@ -388,7 +388,7 @@ internal void OnHeartbeat(bool ifConnectedOnly)
long newSampleCount = Interlocked.Read(ref operationCount);
Interlocked.Exchange(ref profileLog[index % ProfileLogSamples], newSampleCount);
Interlocked.Exchange(ref profileLastLog, newSampleCount);
Trace("OnHeartbeat: " + (State)state);
switch (state)
{
case (int)State.Connecting:
......@@ -709,7 +709,10 @@ private PhysicalConnection GetConnection()
{
Interlocked.Increment(ref socketCount);
Interlocked.Exchange(ref connectStartTicks, Environment.TickCount);
// separate creation and connection for case when connection completes synchronously
// in that case PhysicalConnection will call back to PhysicalBridge, and most of PhysicalBridge methods assumes that physical is not null;
physical = new PhysicalConnection(this);
physical.BeginConnect();
}
}
return null;
......
......@@ -80,11 +80,15 @@ public PhysicalConnection(PhysicalBridge bridge)
var endpoint = bridge.ServerEndPoint.EndPoint;
physicalName = connectionType + "#" + Interlocked.Increment(ref totalCount) + "@" + Format.ToString(endpoint);
this.bridge = bridge;
multiplexer.Trace("Connecting...", physicalName);
OnCreateEcho();
}
public void BeginConnect()
{
var endpoint = this.bridge.ServerEndPoint.EndPoint;
multiplexer.Trace("Connecting...", physicalName);
this.socketToken = multiplexer.SocketManager.BeginConnect(endpoint, this);
//socket.SendTimeout = socket.ReceiveTimeout = multiplexer.TimeoutMilliseconds;
OnCreateEcho();
}
private enum ReadMode : byte
......
......@@ -124,13 +124,15 @@ internal SocketToken BeginConnect(EndPoint endpoint, ISocketCallback callback)
socket.NoDelay = true;
try
{
var ar = socket.BeginConnect(endpoint, EndConnect, Tuple.Create(socket, callback));
if (ar.CompletedSynchronously)
{
ConnectionMultiplexer.TraceWithoutContext("EndConnect (sync)");
EndConnectImpl(ar);
}
} catch (NotImplementedException ex)
CompletionType connectCompletionType = CompletionType.Any;
this.ShouldForceConnectCompletionType(ref connectCompletionType);
CompletionTypeHelper.RunWithCompletionType(
(cb) => socket.BeginConnect(endpoint, cb, Tuple.Create(socket, callback)),
(ar) => EndConnectImpl(ar),
CompletionType.Sync);
}
catch (NotImplementedException ex)
{
if (!(endpoint is IPEndPoint))
{
......@@ -185,14 +187,6 @@ internal void Shutdown(SocketToken token)
Shutdown(token.Socket);
}
private void EndConnect(IAsyncResult ar)
{
if (!ar.CompletedSynchronously)
{
ConnectionMultiplexer.TraceWithoutContext("EndConnect (async)");
EndConnectImpl(ar);
}
}
private void EndConnectImpl(IAsyncResult ar)
{
Tuple<Socket, ISocketCallback> tuple = null;
......@@ -261,6 +255,8 @@ private void EndConnectImpl(IAsyncResult ar)
partial void OnShutdown(Socket socket);
partial void ShouldIgnoreConnect(ISocketCallback callback, ref bool ignore);
partial void ShouldForceConnectCompletionType(ref CompletionType completionType);
[System.Diagnostics.CodeAnalysis.SuppressMessage("Microsoft.Usage", "CA2202:Do not dispose objects multiple times")]
private void Shutdown(Socket socket)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment