Commit 594ad1a9 authored by Marc Gravell's avatar Marc Gravell

A suitably evil hack to avoid exec-sync

parent 71763cb4
The Dangers of Synchronous Continuations
===
This is more of a "don't do this" guide.
When you are using the `*Async` API of StackExchange.Redis, it will hand you either a `Task` or a `Task<T>` that represents your reply when it is available. From here you can do a few things:
- you can ignore it (if you are going to do that, you should specify `CommandFlags.FireAndForget`, to reduce overhead)
- you can asynchronously `await` it
- you can synchronously `.Wait()` it (or `Task.WaitAll` or `.Result`, which do the same)
- you can add a continuation with `.ContinueWith(...)`
The last one of these has overloads that allow you to control the behavior, including one or more overloads that accept a [`TaskContinuationOptions`][1]. And one of these options is `ExecuteSynchronously`.
To put it simply: **do not use `TaskContinuationOptions.ExecuteSynchronously` here**. On other tasks, sure. But please please do not use this option on the task that StackExchange.Redis hands you. The reason
for this is that *if you do*, your continuation could end up interrupting the reader thread that is processing incoming redis data, and in a busy system blocking the reader will cause problems **very** quickly.
If you *can't* control this (and I strongly suggest you try to), then you can change `ConfigurationOptions.AllowSynchronousContinuations` to `false` when creating your `ConnectionMultiplexer` (or add `;syncCont=false` to the configuration string);
this will cause *all* tasks with continuations to be expressly moved *off* the reader thread and completed separately by the thread-pool. This *sounds* tempting, but in a busy system where the thread-pool is under heavy load, this can itself be problematic
(especially if the active workers are currently blocking waiting on responses that can't be actioned because the completions are stuck waiting for a worker - a deadlock). Unfortunately, at the current time
[there isn't much I can do about this](http://stackoverflow.com/q/22579206/23354), other than to advise you not to do it.
To be clear:
- `ContinueWith` by itself is fine
- and I'm sure there are times when `TaskContinuationOptions.ExecuteSynchronously` makes perfect sense on other tasks
- but please do not use `ContinueWith` with `TaskContinuationOptions.ExecuteSynchronously` on the tasks that StackExchange.Redis hands you
[1]: http://msdn.microsoft.com/en-us/library/system.threading.tasks.taskcontinuationoptions(v=vs.110).aspx
\ No newline at end of file
Once, there was more content here; then [a suitably evil workaround was found](http://stackoverflow.com/a/22588431/23354). This page is not
listed in the index, but remains for your curiosity.
\ No newline at end of file
......@@ -31,7 +31,6 @@ Documentation
- [Keys, Values and Channels](https://github.com/StackExchange/StackExchange.Redis/blob/master/Docs/KeysValues.md) - discusses the data-types used on the API
- [Transactions](https://github.com/StackExchange/StackExchange.Redis/blob/master/Docs/Transactions.md) - how atomic transactions work in redis
- [Events](https://github.com/StackExchange/StackExchange.Redis/blob/master/Docs/Events.md) - the events available for logging / information purposes
- [The Dangers of Synchronous Continuations](https://github.com/StackExchange/StackExchange.Redis/blob/master/Docs/ExecSync.md) - one important scenario to avoid
Questions and Contributions
---
......
using System.Threading.Tasks;
using System;
using System.Threading;
using System.Threading.Tasks;
using NUnit.Framework;
namespace StackExchange.Redis.Tests
......@@ -6,17 +8,111 @@ namespace StackExchange.Redis.Tests
[TestFixture]
public class TaskTests
{
#if DEBUG
[Test]
public void CheckContinuationCheck()
[TestCase(SourceOrign.NewTCS, false)]
[TestCase(SourceOrign.Create, false)]
[TestCase(SourceOrign.CreateDenyExec, true)]
public void CheckContinuationCheck(SourceOrign origin, bool expected)
{
TaskCompletionSource<int> tcs = new TaskCompletionSource<int>();
Assert.IsTrue(TaskContinationCheck.NoContinuations(tcs.Task), "vanilla");
var source = Create<int>(origin);
Assert.AreEqual(expected, TaskSource.IsSyncSafe(source.Task));
}
tcs.Task.ContinueWith(x => { });
static TaskCompletionSource<T> Create<T>(SourceOrign origin)
{
switch (origin)
{
case SourceOrign.NewTCS: return new TaskCompletionSource<T>();
case SourceOrign.Create: return TaskSource.Create<T>(null);
case SourceOrign.CreateDenyExec: return TaskSource.CreateDenyExecSync<T>(null);
default: throw new ArgumentOutOfRangeException("origin");
}
}
[Test]
// regular framework behaviour: 2 out of 3 cause hijack
[TestCase(SourceOrign.NewTCS, AttachMode.ContinueWith, false)]
[TestCase(SourceOrign.NewTCS, AttachMode.ContinueWithExecSync, true)]
[TestCase(SourceOrign.NewTCS, AttachMode.Await, true)]
// Create is just a wrapper of ^^^; expect the same
[TestCase(SourceOrign.Create, AttachMode.ContinueWith, false)]
[TestCase(SourceOrign.Create, AttachMode.ContinueWithExecSync, true)]
[TestCase(SourceOrign.Create, AttachMode.Await, true)]
// deny exec-sync: none should cause hijack
[TestCase(SourceOrign.CreateDenyExec, AttachMode.ContinueWith, false)]
[TestCase(SourceOrign.CreateDenyExec, AttachMode.ContinueWithExecSync, false)]
[TestCase(SourceOrign.CreateDenyExec, AttachMode.Await, false)]
public void TestContinuationHijacking(SourceOrign origin, AttachMode attachMode, bool expectHijack)
{
TaskCompletionSource<int> source = Create<int>(origin);
Assert.IsFalse(TaskContinationCheck.NoContinuations(tcs.Task), "dirty");
int settingThread = Environment.CurrentManagedThreadId;
var state = new AwaitState();
state.Attach(source.Task, attachMode);
source.TrySetResult(123);
state.Wait(); // waits for the continuation to run
int from = state.Thread;
Assert.AreNotEqual(-1, from, "not set");
if (expectHijack)
{
Assert.AreEqual(settingThread, from, "expected hijack; didn't happen");
}
else
{
Assert.AreNotEqual(settingThread, from, "setter was hijacked");
}
}
public enum SourceOrign
{
NewTCS,
Create,
CreateDenyExec
}
public enum AttachMode
{
ContinueWith,
ContinueWithExecSync,
Await
}
class AwaitState
{
public int Thread { get { return continuationThread; } }
volatile int continuationThread = -1;
private ManualResetEventSlim evt = new ManualResetEventSlim();
public void Wait()
{
if (!evt.Wait(5000)) throw new TimeoutException();
}
public void Attach(Task task, AttachMode attachMode)
{
switch(attachMode)
{
case AttachMode.ContinueWith:
task.ContinueWith(Continue);
break;
case AttachMode.ContinueWithExecSync:
task.ContinueWith(Continue, TaskContinuationOptions.ExecuteSynchronously);
break;
case AttachMode.Await:
DoAwait(task);
break;
default:
throw new ArgumentOutOfRangeException("attachMode");
}
}
private void Continue(Task task)
{
continuationThread = Environment.CurrentManagedThreadId;
evt.Set();
}
private async void DoAwait(Task task)
{
await task;
continuationThread = Environment.CurrentManagedThreadId;
evt.Set();
}
}
#endif
}
}
......@@ -139,7 +139,7 @@
<Compile Include="StackExchange\Redis\SocketManager.cs" />
<Compile Include="StackExchange\Redis\SortType.cs" />
<Compile Include="StackExchange\Redis\StringSplits.cs" />
<Compile Include="StackExchange\Redis\TaskContinuationCheck.cs" />
<Compile Include="StackExchange\Redis\TaskSource.cs" />
<Compile Include="StackExchange\Redis\When.cs" />
<Compile Include="StackExchange\Redis\ShutdownMode.cs" />
<Compile Include="StackExchange\Redis\SaveType.cs" />
......
......@@ -12,7 +12,9 @@ public static Task<T> Default(object asyncState)
}
public static Task<T> FromResult(T value, object asyncState)
{
var tcs = new TaskCompletionSource<T>(asyncState);
// note we do not need to deny exec-sync here; the value will be known
// before we hand it to them
var tcs = TaskSource.Create<T>(asyncState);
tcs.SetResult(value);
return tcs.Task;
}
......
......@@ -17,17 +17,15 @@ sealed partial class CompletionManager
private readonly string name;
long completedSync, completedAsync, failedAsync;
private readonly bool allowSyncContinuations;
public CompletionManager(ConnectionMultiplexer multiplexer, string name)
{
this.multiplexer = multiplexer;
this.name = name;
this.allowSyncContinuations = multiplexer.RawConfig.AllowSynchronousContinuations;
}
public void CompleteSyncOrAsync(ICompletable operation)
{
if (operation == null) return;
if (operation.TryComplete(false, allowSyncContinuations))
if (operation.TryComplete(false))
{
multiplexer.Trace("Completed synchronously: " + operation, name);
Interlocked.Increment(ref completedSync);
......@@ -98,7 +96,7 @@ private static void AnyOrderCompletionHandler(object state)
try
{
ConnectionMultiplexer.TraceWithoutContext("Completing async (any order): " + state);
((ICompletable)state).TryComplete(true, true);
((ICompletable)state).TryComplete(true);
}
catch (Exception ex)
{
......@@ -135,7 +133,7 @@ private void ProcessAsyncCompletionQueueImpl()
try
{
multiplexer.Trace("Completing async (ordered): " + next, name);
next.TryComplete(true, allowSyncContinuations);
next.TryComplete(true);
Interlocked.Increment(ref completedAsync);
}
catch(Exception ex)
......
......@@ -20,7 +20,7 @@ public sealed class ConfigurationOptions : ICloneable
VersionPrefix = "version=", ConnectTimeoutPrefix = "connectTimeout=", PasswordPrefix = "password=",
TieBreakerPrefix = "tiebreaker=", WriteBufferPrefix = "writeBuffer=", SslHostPrefix = "sslHost=",
ConfigChannelPrefix = "configChannel=", AbortOnConnectFailPrefix = "abortConnect=", ResolveDnsPrefix = "resolveDns=",
ChannelPrefixPrefix = "channelPrefix=", AllowSyncContinuationsPrefix = "syncCont=";
ChannelPrefixPrefix = "channelPrefix=";
private readonly EndPointCollection endpoints = new EndPointCollection();
......@@ -29,7 +29,7 @@ public sealed class ConfigurationOptions : ICloneable
/// </summary>
public RedisChannel ChannelPrefix { get;set; }
private bool? allowAdmin, abortOnConnectFail, resolveDns, allowSyncContinuations;
private bool? allowAdmin, abortOnConnectFail, resolveDns;
private string clientName, serviceName, password, tieBreaker, sslHost, configChannel;
private Version defaultVersion;
......@@ -149,13 +149,6 @@ public ConfigurationOptions()
/// </summary>
public bool AbortOnConnectFail { get { return abortOnConnectFail ?? true; } set { abortOnConnectFail = value; } }
/// <summary>
/// Gets or sets whether synchronous task continuations should be explicitly avoided (allowed by default)
/// </summary>
public bool AllowSynchronousContinuations { get { return allowSyncContinuations ?? true; } set { allowSyncContinuations = value; } }
/// <summary>
/// Parse the configuration from a comma-delimited configuration string
/// </summary>
......@@ -178,7 +171,6 @@ public ConfigurationOptions Clone()
keepAlive = keepAlive,
syncTimeout = syncTimeout,
allowAdmin = allowAdmin,
allowSyncContinuations = allowSyncContinuations,
defaultVersion = defaultVersion,
connectTimeout = connectTimeout,
password = password,
......@@ -225,7 +217,6 @@ public override string ToString()
Append(sb, AbortOnConnectFailPrefix, abortOnConnectFail);
Append(sb, ResolveDnsPrefix, resolveDns);
Append(sb, ChannelPrefixPrefix, (string)ChannelPrefix);
Append(sb, AllowSyncContinuationsPrefix, allowSyncContinuations);
CommandMap.AppendDeltas(sb);
return sb.ToString();
}
......@@ -307,7 +298,7 @@ void Clear()
{
clientName = serviceName = password = tieBreaker = sslHost = configChannel = null;
keepAlive = syncTimeout = connectTimeout = writeBuffer = null;
allowAdmin = abortOnConnectFail = resolveDns = allowSyncContinuations = null;
allowAdmin = abortOnConnectFail = resolveDns = null;
defaultVersion = null;
endpoints.Clear();
CertificateSelection = null;
......@@ -357,11 +348,6 @@ private void DoParse(string configuration)
{
bool tmp;
if (Format.TryParseBoolean(value.Trim(), out tmp)) ResolveDns = tmp;
}
else if (IsOption(option, AllowSyncContinuationsPrefix))
{
bool tmp;
if (Format.TryParseBoolean(value.Trim(), out tmp)) AllowSynchronousContinuations = tmp;
}
else if (IsOption(option, ServiceNamePrefix))
{
......
......@@ -56,7 +56,7 @@ public ConnectionFailureType FailureType
{
get { return failureType; }
}
bool ICompletable.TryComplete(bool isAsync, bool allowSyncContinuations)
bool ICompletable.TryComplete(bool isAsync)
{
return ConnectionMultiplexer.TryCompleteHandler(handler, sender, this, isAsync);
}
......
......@@ -1593,7 +1593,7 @@ internal Task<T> ExecuteAsyncImpl<T>(Message message, ResultProcessor<T> process
}
else
{
var tcs = new TaskCompletionSource<T>(state);
var tcs = TaskSource.CreateDenyExecSync<T>(state);
var source = ResultBox<T>.Get(tcs);
if (!TryPushMessageToBridge(message, processor, source, ref server))
{
......
......@@ -26,7 +26,7 @@ public EndPoint EndPoint
{
get { return endpoint; }
}
bool ICompletable.TryComplete(bool isAsync, bool allowSyncContinuations)
bool ICompletable.TryComplete(bool isAsync)
{
return ConnectionMultiplexer.TryCompleteHandler(handler, sender, this, isAsync);
}
......
......@@ -36,7 +36,7 @@ public sealed class HashSlotMovedEventArgs : EventArgs, ICompletable
this.@new = @new;
}
bool ICompletable.TryComplete(bool isAsync, bool allowSyncContinuations)
bool ICompletable.TryComplete(bool isAsync)
{
return ConnectionMultiplexer.TryCompleteHandler(handler, sender, this, isAsync);
}
......
......@@ -4,7 +4,7 @@ namespace StackExchange.Redis
{
interface ICompletable
{
bool TryComplete(bool isAsync, bool allowSyncContinuations);
bool TryComplete(bool isAsync);
void AppendStormLog(StringBuilder sb);
}
}
......@@ -56,7 +56,7 @@ public string Origin
{
get { return origin; }
}
bool ICompletable.TryComplete(bool isAsync, bool allowSyncContinuations)
bool ICompletable.TryComplete(bool isAsync)
{
return ConnectionMultiplexer.TryCompleteHandler(handler, sender, this, isAsync);
}
......
......@@ -370,11 +370,11 @@ public override string ToString()
resultProcessor == null ? "(n/a)" : resultProcessor.GetType().Name);
}
public bool TryComplete(bool isAsync, bool allowSyncContinuations)
public bool TryComplete(bool isAsync)
{
if (resultBox != null)
{
return resultBox.TryComplete(isAsync, allowSyncContinuations);
return resultBox.TryComplete(isAsync);
}
else
{
......
......@@ -22,7 +22,7 @@ public override string ToString()
{
return (string)channel;
}
public bool TryComplete(bool isAsync, bool allowSyncContinuations)
public bool TryComplete(bool isAsync)
{
if (handler == null) return true;
if (isAsync)
......
......@@ -78,7 +78,7 @@ internal override Task<T> ExecuteAsync<T>(Message message, ResultProcessor<T> pr
}
else
{
var tcs = new TaskCompletionSource<T>(asyncState);
var tcs = TaskSource.CreateDenyExecSync<T>(asyncState);
var source = ResultBox<T>.Get(tcs);
message.SetSource(source, processor);
task = tcs.Task;
......
......@@ -38,7 +38,7 @@ void ICompletable.AppendStormLog(StringBuilder sb)
sb.Append("event, error: ").Append(message);
}
bool ICompletable.TryComplete(bool isAsync, bool allowSyncContinuations)
bool ICompletable.TryComplete(bool isAsync)
{
return ConnectionMultiplexer.TryCompleteHandler(handler, sender, this, isAsync);
}
......
......@@ -470,7 +470,8 @@ internal override Task<T> ExecuteAsync<T>(Message message, ResultProcessor<T> pr
if (message == null) return CompletedTask<T>.Default(asyncState);
if (message.IsFireAndForget) return CompletedTask<T>.Default(null); // F+F explicitly does not get async-state
var tcs = new TaskCompletionSource<T>(asyncState);
// no need to deny exec-sync here; will be complete before they see if
var tcs = TaskSource.Create<T>(asyncState);
ConnectionMultiplexer.ThrowFailed(tcs, ExceptionFactory.NoConnectionAvailable(message.Command));
return tcs.Task;
}
......
......@@ -72,7 +72,7 @@ internal override Task<T> ExecuteAsync<T>(Message message, ResultProcessor<T> pr
}
else
{
var tcs = new TaskCompletionSource<T>(asyncState);
var tcs = TaskSource.CreateDenyExecSync<T>(asyncState);
var source = ResultBox<T>.Get(tcs);
message.SetSource(source, processor);
task = tcs.Task;
......
......@@ -21,7 +21,7 @@ public void SetException(Exception exception)
// this.exception = caught;
//}
}
public abstract bool TryComplete(bool isAsync, bool allowSyncContinuations);
public abstract bool TryComplete(bool isAsync);
[Conditional("DEBUG")]
protected static void IncrementAllocationCount()
......@@ -94,12 +94,12 @@ public void SetResult(T value)
this.value = value;
}
public override bool TryComplete(bool isAsync, bool allowSyncContinuations)
public override bool TryComplete(bool isAsync)
{
if (stateOrCompletionSource is TaskCompletionSource<T>)
{
var tcs = (TaskCompletionSource<T>)stateOrCompletionSource;
if (isAsync || allowSyncContinuations || TaskContinationCheck.NoContinuations(tcs.Task))
if (isAsync || TaskSource.IsSyncSafe(tcs.Task))
{
T val;
Exception ex;
......
......@@ -336,7 +336,7 @@ internal void OnHeartbeat()
internal Task<T> QueueDirectAsync<T>(Message message, ResultProcessor<T> processor, object asyncState = null, PhysicalBridge bridge = null)
{
var tcs = new TaskCompletionSource<T>(asyncState);
var tcs = TaskSource.CreateDenyExecSync<T>(asyncState);
var source = ResultBox<T>.Get(tcs);
message.SetSource(processor, source);
if(!(bridge ?? GetBridge(message.Command)).TryEnqueue(message, isSlave))
......
using System;
using System.Reflection;
using System.Reflection.Emit;
using System.Threading;
using System.Threading.Tasks;
namespace StackExchange.Redis
{
/// <summary>
/// Utility to detect continuations on tasks
/// </summary>
public static class TaskContinationCheck
{
static TaskContinationCheck()
{
NoContinuations = task => false; // assume the worst, hope for the best
try
{
var field = typeof(Task).GetField("m_continuationObject", BindingFlags.NonPublic | BindingFlags.Instance);
if (field == null)
{
System.Diagnostics.Trace.WriteLine("Expected field not found: Task.m_continuationObject");
return;
}
var method = new DynamicMethod("NoContinuations", typeof(bool), new[] { typeof(Task) },
typeof(Task), true);
var il = method.GetILGenerator();
il.Emit(OpCodes.Ldarg_0);
il.Emit(OpCodes.Ldflda, field);
il.Emit(OpCodes.Ldnull);
il.Emit(OpCodes.Ldnull);
il.EmitCall(OpCodes.Call, typeof(Interlocked).GetMethod("CompareExchange", new[] { typeof(object).MakeByRefType(), typeof(object), typeof(object) }), null);
il.Emit(OpCodes.Ldnull);
il.Emit(OpCodes.Ceq);
il.Emit(OpCodes.Ret);
var func = (Func<Task, bool>)method.CreateDelegate(typeof(Func<Task, bool>));
TaskCompletionSource<int> source = new TaskCompletionSource<int>();
var before = func(source.Task);
source.Task.ContinueWith(t => { });
var after = func(source.Task);
if (!before)
{
System.Diagnostics.Trace.WriteLine("vanilla task should report true");
return;
}
if (after)
{
System.Diagnostics.Trace.WriteLine("task with continuation should report false");
return;
}
source.TrySetResult(0);
NoContinuations = func;
}
catch (Exception ex)
{
System.Diagnostics.Trace.WriteLine(ex.Message);
}
}
/// <summary>
/// Does the specified task have no continuations?
/// </summary>
public static readonly Func<Task, bool> NoContinuations;
}
}
using System;
using System.Diagnostics;
using System.Reflection;
using System.Reflection.Emit;
using System.Threading.Tasks;
namespace StackExchange.Redis
{
/// <summary>
/// We want to prevent callers hijacking the reader thread; this is a bit nasty, but works;
/// see http://stackoverflow.com/a/22588431/23354 for more information; a huge
/// thanks to Eli Arbel for spotting this (evin if it is pure evil)
/// </summary>
#if DEBUG
public // for the unit tests in TaskTests.cs
#endif
static class TaskSource
{
/// <summary>
/// Indicates whether the specified task will not hijack threads when results are set
/// </summary>
public static readonly Func<Task, bool> IsSyncSafe;
static Action<Task> denyExecSync;
static TaskSource()
{
try
{
var stateField = typeof(Task).GetField("m_stateFlags", BindingFlags.Instance | BindingFlags.NonPublic);
if (stateField != null)
{
var constField = typeof(Task).GetField("TASK_STATE_THREAD_WAS_ABORTED", BindingFlags.Static | BindingFlags.NonPublic);
// try to use the literal field value, but settle for hard-coded if it isn't there
int flag = constField == null ? 134217728 : (int)constField.GetRawConstantValue();
var method = new DynamicMethod("DenyExecSync", null, new[] { typeof(Task) }, typeof(Task), true);
var il = method.GetILGenerator();
il.Emit(OpCodes.Ldarg_0); // [task]
il.Emit(OpCodes.Ldarg_0); // [task, task]
il.Emit(OpCodes.Ldfld, stateField); // [task, flags]
il.Emit(OpCodes.Ldc_I4, flag); // [task, flags, 134217728]
il.Emit(OpCodes.Or); // [task, combined]
il.Emit(OpCodes.Stfld, stateField); // []
il.Emit(OpCodes.Ret);
denyExecSync = (Action<Task>)method.CreateDelegate(typeof(Action<Task>));
method = new DynamicMethod("IsSyncSafe", typeof(bool), new[] { typeof(Task) }, typeof(Task), true);
il = method.GetILGenerator();
il.Emit(OpCodes.Ldc_I4, flag); // [134217728]
il.Emit(OpCodes.Ldarg_0); // [134217728, task]
il.Emit(OpCodes.Ldfld, stateField); // [134217728, flags]
il.Emit(OpCodes.Ldc_I4, flag); // [134217728, flags, 134217728]
il.Emit(OpCodes.And); // [134217728, single-flag]
il.Emit(OpCodes.Ceq); // [true/false]
il.Emit(OpCodes.Ret);
IsSyncSafe = (Func<Task, bool>)method.CreateDelegate(typeof(Func<Task, bool>));
}
}
catch(Exception ex)
{
Debug.WriteLine(ex.Message);
Trace.WriteLine(ex.Message);
}
if(denyExecSync == null)
denyExecSync = t => { }; // no-op if that fails
if (IsSyncSafe == null)
IsSyncSafe = t => false; // assume: not
}
/// <summary>
/// Create a new TaskCompletionSource that will not allow result-setting threads to be hijacked
/// </summary>
public static TaskCompletionSource<T> CreateDenyExecSync<T>(object asyncState)
{
var source = new TaskCompletionSource<T>(asyncState);
denyExecSync(source.Task);
return source;
}
/// <summary>
/// Create a new TaskCompletion source
/// </summary>
public static TaskCompletionSource<T> Create<T>(object asyncState)
{
return new TaskCompletionSource<T>(asyncState);
}
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment