Commit ab182e4a authored by Pavel Pochobut's avatar Pavel Pochobut

Fixed race condition in PhysicalConnection creation

parent 5ba566b1
using NUnit.Framework;
using System;
using System.Threading;
namespace StackExchange.Redis.Tests
{
[TestFixture]
public class ConnectingFailDetection : TestBase
{
#if DEBUG
[TestCase]
public void FastNoticesFailOnConnectingSync()
{
try
{
using (var muxer = Create(keepAlive: 1, connectTimeout: 10000, allowAdmin: true))
{
var conn = muxer.GetDatabase();
conn.Ping();
var server = muxer.GetServer(muxer.GetEndPoints()[0]);
muxer.AllowConnect = false;
SocketManager.ConnectCompletionType = CompletionType.Sync;
server.SimulateConnectionFailure();
Assert.IsFalse(muxer.IsConnected);
// should reconnect within 1 keepalive interval
muxer.AllowConnect = true;
Console.WriteLine("Waiting for reconnect");
Thread.Sleep(2000);
Assert.IsTrue(muxer.IsConnected);
}
ClearAmbientFailures();
}
finally
{
SocketManager.ConnectCompletionType = CompletionType.Any;
}
}
[TestCase]
public void FastNoticesFailOnConnectingAsync()
{
try
{
using (var muxer = Create(keepAlive: 1, connectTimeout: 10000, allowAdmin: true))
{
var conn = muxer.GetDatabase();
conn.Ping();
var server = muxer.GetServer(muxer.GetEndPoints()[0]);
muxer.AllowConnect = false;
SocketManager.ConnectCompletionType = CompletionType.Async;
server.SimulateConnectionFailure();
Assert.IsFalse(muxer.IsConnected);
// should reconnect within 1 keepalive interval
muxer.AllowConnect = true;
Console.WriteLine("Waiting for reconnect");
Thread.Sleep(2000);
Assert.IsTrue(muxer.IsConnected);
ClearAmbientFailures();
}
}
finally
{
SocketManager.ConnectCompletionType = CompletionType.Any;
}
}
#endif
}
}
...@@ -66,6 +66,7 @@ ...@@ -66,6 +66,7 @@
<ItemGroup> <ItemGroup>
<Compile Include="AsyncTests.cs" /> <Compile Include="AsyncTests.cs" />
<Compile Include="BasicOps.cs" /> <Compile Include="BasicOps.cs" />
<Compile Include="ConnectingFailDetection.cs" />
<Compile Include="HyperLogLog.cs" /> <Compile Include="HyperLogLog.cs" />
<Compile Include="WrapperBaseTests.cs" /> <Compile Include="WrapperBaseTests.cs" />
<Compile Include="TransactionWrapperTests.cs" /> <Compile Include="TransactionWrapperTests.cs" />
......
...@@ -76,6 +76,10 @@ protected void OnInternalError(object sender, InternalErrorEventArgs e) ...@@ -76,6 +76,10 @@ protected void OnInternalError(object sender, InternalErrorEventArgs e)
volatile int expectedFailCount; volatile int expectedFailCount;
[SetUp] [SetUp]
public void Setup() public void Setup()
{
ClearAmbientFailures();
}
public void ClearAmbientFailures()
{ {
Collect(); Collect();
Interlocked.Exchange(ref failCount, 0); Interlocked.Exchange(ref failCount, 0);
......
...@@ -191,6 +191,16 @@ partial class SocketManager ...@@ -191,6 +191,16 @@ partial class SocketManager
{ {
ignore = callback.IgnoreConnect; ignore = callback.IgnoreConnect;
} }
/// <summary>
/// Completion type for BeginConnect call
/// </summary>
public static CompletionType ConnectCompletionType { get; set; }
partial void ShouldForceConnectCompletionType(ref CompletionType completionType)
{
completionType = SocketManager.ConnectCompletionType;
}
} }
partial interface ISocketCallback partial interface ISocketCallback
{ {
...@@ -253,6 +263,69 @@ bool ISocketCallback.IgnoreConnect ...@@ -253,6 +263,69 @@ bool ISocketCallback.IgnoreConnect
} }
#endif #endif
/// <summary>
/// Completion type for CompletionTypeHelper
/// </summary>
public enum CompletionType
{
/// <summary>
/// Retain original completion type (either sync or async)
/// </summary>
Any = 0,
/// <summary>
/// Force sync completion
/// </summary>
Sync = 1,
/// <summary>
/// Force async completion
/// </summary>
Async = 2
}
internal class CompletionTypeHelper
{
public static void RunWithCompletionType(Func<AsyncCallback, IAsyncResult> beginAsync, AsyncCallback callback, CompletionType completionType)
{
AsyncCallback proxyCallback;
if (completionType == CompletionType.Any)
{
proxyCallback = (ar) =>
{
if (!ar.CompletedSynchronously)
{
callback(ar);
}
};
}
else
{
proxyCallback = (ar) => { };
}
var result = beginAsync(proxyCallback);
if (completionType == CompletionType.Any && !result.CompletedSynchronously)
{
return;
}
result.AsyncWaitHandle.WaitOne();
switch (completionType)
{
case CompletionType.Async:
ThreadPool.QueueUserWorkItem((s) => { callback(result); });
break;
case CompletionType.Any:
case CompletionType.Sync:
callback(result);
break;
}
return;
}
}
#if VERBOSE #if VERBOSE
partial class ConnectionMultiplexer partial class ConnectionMultiplexer
......
...@@ -388,7 +388,7 @@ internal void OnHeartbeat(bool ifConnectedOnly) ...@@ -388,7 +388,7 @@ internal void OnHeartbeat(bool ifConnectedOnly)
long newSampleCount = Interlocked.Read(ref operationCount); long newSampleCount = Interlocked.Read(ref operationCount);
Interlocked.Exchange(ref profileLog[index % ProfileLogSamples], newSampleCount); Interlocked.Exchange(ref profileLog[index % ProfileLogSamples], newSampleCount);
Interlocked.Exchange(ref profileLastLog, newSampleCount); Interlocked.Exchange(ref profileLastLog, newSampleCount);
Trace("OnHeartbeat: " + (State)state);
switch (state) switch (state)
{ {
case (int)State.Connecting: case (int)State.Connecting:
...@@ -709,7 +709,10 @@ private PhysicalConnection GetConnection() ...@@ -709,7 +709,10 @@ private PhysicalConnection GetConnection()
{ {
Interlocked.Increment(ref socketCount); Interlocked.Increment(ref socketCount);
Interlocked.Exchange(ref connectStartTicks, Environment.TickCount); Interlocked.Exchange(ref connectStartTicks, Environment.TickCount);
// separate creation and connection for case when connection completes synchronously
// in that case PhysicalConnection will call back to PhysicalBridge, and most of PhysicalBridge methods assumes that physical is not null;
physical = new PhysicalConnection(this); physical = new PhysicalConnection(this);
physical.BeginConnect();
} }
} }
return null; return null;
......
...@@ -80,11 +80,15 @@ public PhysicalConnection(PhysicalBridge bridge) ...@@ -80,11 +80,15 @@ public PhysicalConnection(PhysicalBridge bridge)
var endpoint = bridge.ServerEndPoint.EndPoint; var endpoint = bridge.ServerEndPoint.EndPoint;
physicalName = connectionType + "#" + Interlocked.Increment(ref totalCount) + "@" + Format.ToString(endpoint); physicalName = connectionType + "#" + Interlocked.Increment(ref totalCount) + "@" + Format.ToString(endpoint);
this.bridge = bridge; this.bridge = bridge;
multiplexer.Trace("Connecting...", physicalName); OnCreateEcho();
}
public void BeginConnect()
{
var endpoint = this.bridge.ServerEndPoint.EndPoint;
multiplexer.Trace("Connecting...", physicalName);
this.socketToken = multiplexer.SocketManager.BeginConnect(endpoint, this); this.socketToken = multiplexer.SocketManager.BeginConnect(endpoint, this);
//socket.SendTimeout = socket.ReceiveTimeout = multiplexer.TimeoutMilliseconds;
OnCreateEcho();
} }
private enum ReadMode : byte private enum ReadMode : byte
......
...@@ -124,13 +124,15 @@ internal SocketToken BeginConnect(EndPoint endpoint, ISocketCallback callback) ...@@ -124,13 +124,15 @@ internal SocketToken BeginConnect(EndPoint endpoint, ISocketCallback callback)
socket.NoDelay = true; socket.NoDelay = true;
try try
{ {
var ar = socket.BeginConnect(endpoint, EndConnect, Tuple.Create(socket, callback)); CompletionType connectCompletionType = CompletionType.Any;
if (ar.CompletedSynchronously) this.ShouldForceConnectCompletionType(ref connectCompletionType);
{
ConnectionMultiplexer.TraceWithoutContext("EndConnect (sync)"); CompletionTypeHelper.RunWithCompletionType(
EndConnectImpl(ar); (cb) => socket.BeginConnect(endpoint, cb, Tuple.Create(socket, callback)),
} (ar) => EndConnectImpl(ar),
} catch (NotImplementedException ex) CompletionType.Sync);
}
catch (NotImplementedException ex)
{ {
if (!(endpoint is IPEndPoint)) if (!(endpoint is IPEndPoint))
{ {
...@@ -185,14 +187,6 @@ internal void Shutdown(SocketToken token) ...@@ -185,14 +187,6 @@ internal void Shutdown(SocketToken token)
Shutdown(token.Socket); Shutdown(token.Socket);
} }
private void EndConnect(IAsyncResult ar)
{
if (!ar.CompletedSynchronously)
{
ConnectionMultiplexer.TraceWithoutContext("EndConnect (async)");
EndConnectImpl(ar);
}
}
private void EndConnectImpl(IAsyncResult ar) private void EndConnectImpl(IAsyncResult ar)
{ {
Tuple<Socket, ISocketCallback> tuple = null; Tuple<Socket, ISocketCallback> tuple = null;
...@@ -261,6 +255,8 @@ private void EndConnectImpl(IAsyncResult ar) ...@@ -261,6 +255,8 @@ private void EndConnectImpl(IAsyncResult ar)
partial void OnShutdown(Socket socket); partial void OnShutdown(Socket socket);
partial void ShouldIgnoreConnect(ISocketCallback callback, ref bool ignore); partial void ShouldIgnoreConnect(ISocketCallback callback, ref bool ignore);
partial void ShouldForceConnectCompletionType(ref CompletionType completionType);
[System.Diagnostics.CodeAnalysis.SuppressMessage("Microsoft.Usage", "CA2202:Do not dispose objects multiple times")] [System.Diagnostics.CodeAnalysis.SuppressMessage("Microsoft.Usage", "CA2202:Do not dispose objects multiple times")]
private void Shutdown(Socket socket) private void Shutdown(Socket socket)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment