Commit 594ad1a9 authored by Marc Gravell's avatar Marc Gravell

A suitably evil hack to avoid exec-sync

parent 71763cb4
The Dangers of Synchronous Continuations The Dangers of Synchronous Continuations
=== ===
This is more of a "don't do this" guide. Once, there was more content here; then [a suitably evil workaround was found](http://stackoverflow.com/a/22588431/23354). This page is not
listed in the index, but remains for your curiosity.
When you are using the `*Async` API of StackExchange.Redis, it will hand you either a `Task` or a `Task<T>` that represents your reply when it is available. From here you can do a few things: \ No newline at end of file
- you can ignore it (if you are going to do that, you should specify `CommandFlags.FireAndForget`, to reduce overhead)
- you can asynchronously `await` it
- you can synchronously `.Wait()` it (or `Task.WaitAll` or `.Result`, which do the same)
- you can add a continuation with `.ContinueWith(...)`
The last one of these has overloads that allow you to control the behavior, including one or more overloads that accept a [`TaskContinuationOptions`][1]. And one of these options is `ExecuteSynchronously`.
To put it simply: **do not use `TaskContinuationOptions.ExecuteSynchronously` here**. On other tasks, sure. But please please do not use this option on the task that StackExchange.Redis hands you. The reason
for this is that *if you do*, your continuation could end up interrupting the reader thread that is processing incoming redis data, and in a busy system blocking the reader will cause problems **very** quickly.
If you *can't* control this (and I strongly suggest you try to), then you can change `ConfigurationOptions.AllowSynchronousContinuations` to `false` when creating your `ConnectionMultiplexer` (or add `;syncCont=false` to the configuration string);
this will cause *all* tasks with continuations to be expressly moved *off* the reader thread and completed separately by the thread-pool. This *sounds* tempting, but in a busy system where the thread-pool is under heavy load, this can itself be problematic
(especially if the active workers are currently blocking waiting on responses that can't be actioned because the completions are stuck waiting for a worker - a deadlock). Unfortunately, at the current time
[there isn't much I can do about this](http://stackoverflow.com/q/22579206/23354), other than to advise you not to do it.
To be clear:
- `ContinueWith` by itself is fine
- and I'm sure there are times when `TaskContinuationOptions.ExecuteSynchronously` makes perfect sense on other tasks
- but please do not use `ContinueWith` with `TaskContinuationOptions.ExecuteSynchronously` on the tasks that StackExchange.Redis hands you
[1]: http://msdn.microsoft.com/en-us/library/system.threading.tasks.taskcontinuationoptions(v=vs.110).aspx
\ No newline at end of file
...@@ -31,7 +31,6 @@ Documentation ...@@ -31,7 +31,6 @@ Documentation
- [Keys, Values and Channels](https://github.com/StackExchange/StackExchange.Redis/blob/master/Docs/KeysValues.md) - discusses the data-types used on the API - [Keys, Values and Channels](https://github.com/StackExchange/StackExchange.Redis/blob/master/Docs/KeysValues.md) - discusses the data-types used on the API
- [Transactions](https://github.com/StackExchange/StackExchange.Redis/blob/master/Docs/Transactions.md) - how atomic transactions work in redis - [Transactions](https://github.com/StackExchange/StackExchange.Redis/blob/master/Docs/Transactions.md) - how atomic transactions work in redis
- [Events](https://github.com/StackExchange/StackExchange.Redis/blob/master/Docs/Events.md) - the events available for logging / information purposes - [Events](https://github.com/StackExchange/StackExchange.Redis/blob/master/Docs/Events.md) - the events available for logging / information purposes
- [The Dangers of Synchronous Continuations](https://github.com/StackExchange/StackExchange.Redis/blob/master/Docs/ExecSync.md) - one important scenario to avoid
Questions and Contributions Questions and Contributions
--- ---
......
using System.Threading.Tasks; using System;
using System.Threading;
using System.Threading.Tasks;
using NUnit.Framework; using NUnit.Framework;
namespace StackExchange.Redis.Tests namespace StackExchange.Redis.Tests
...@@ -6,17 +8,111 @@ namespace StackExchange.Redis.Tests ...@@ -6,17 +8,111 @@ namespace StackExchange.Redis.Tests
[TestFixture] [TestFixture]
public class TaskTests public class TaskTests
{ {
#if DEBUG
[Test] [Test]
public void CheckContinuationCheck() [TestCase(SourceOrign.NewTCS, false)]
[TestCase(SourceOrign.Create, false)]
[TestCase(SourceOrign.CreateDenyExec, true)]
public void CheckContinuationCheck(SourceOrign origin, bool expected)
{ {
TaskCompletionSource<int> tcs = new TaskCompletionSource<int>(); var source = Create<int>(origin);
Assert.AreEqual(expected, TaskSource.IsSyncSafe(source.Task));
Assert.IsTrue(TaskContinationCheck.NoContinuations(tcs.Task), "vanilla"); }
tcs.Task.ContinueWith(x => { }); static TaskCompletionSource<T> Create<T>(SourceOrign origin)
{
switch (origin)
{
case SourceOrign.NewTCS: return new TaskCompletionSource<T>();
case SourceOrign.Create: return TaskSource.Create<T>(null);
case SourceOrign.CreateDenyExec: return TaskSource.CreateDenyExecSync<T>(null);
default: throw new ArgumentOutOfRangeException("origin");
}
}
[Test]
// regular framework behaviour: 2 out of 3 cause hijack
[TestCase(SourceOrign.NewTCS, AttachMode.ContinueWith, false)]
[TestCase(SourceOrign.NewTCS, AttachMode.ContinueWithExecSync, true)]
[TestCase(SourceOrign.NewTCS, AttachMode.Await, true)]
// Create is just a wrapper of ^^^; expect the same
[TestCase(SourceOrign.Create, AttachMode.ContinueWith, false)]
[TestCase(SourceOrign.Create, AttachMode.ContinueWithExecSync, true)]
[TestCase(SourceOrign.Create, AttachMode.Await, true)]
// deny exec-sync: none should cause hijack
[TestCase(SourceOrign.CreateDenyExec, AttachMode.ContinueWith, false)]
[TestCase(SourceOrign.CreateDenyExec, AttachMode.ContinueWithExecSync, false)]
[TestCase(SourceOrign.CreateDenyExec, AttachMode.Await, false)]
public void TestContinuationHijacking(SourceOrign origin, AttachMode attachMode, bool expectHijack)
{
TaskCompletionSource<int> source = Create<int>(origin);
Assert.IsFalse(TaskContinationCheck.NoContinuations(tcs.Task), "dirty"); int settingThread = Environment.CurrentManagedThreadId;
var state = new AwaitState();
state.Attach(source.Task, attachMode);
source.TrySetResult(123);
state.Wait(); // waits for the continuation to run
int from = state.Thread;
Assert.AreNotEqual(-1, from, "not set");
if (expectHijack)
{
Assert.AreEqual(settingThread, from, "expected hijack; didn't happen");
}
else
{
Assert.AreNotEqual(settingThread, from, "setter was hijacked");
}
}
public enum SourceOrign
{
NewTCS,
Create,
CreateDenyExec
}
public enum AttachMode
{
ContinueWith,
ContinueWithExecSync,
Await
}
class AwaitState
{
public int Thread { get { return continuationThread; } }
volatile int continuationThread = -1;
private ManualResetEventSlim evt = new ManualResetEventSlim();
public void Wait()
{
if (!evt.Wait(5000)) throw new TimeoutException();
}
public void Attach(Task task, AttachMode attachMode)
{
switch(attachMode)
{
case AttachMode.ContinueWith:
task.ContinueWith(Continue);
break;
case AttachMode.ContinueWithExecSync:
task.ContinueWith(Continue, TaskContinuationOptions.ExecuteSynchronously);
break;
case AttachMode.Await:
DoAwait(task);
break;
default:
throw new ArgumentOutOfRangeException("attachMode");
}
}
private void Continue(Task task)
{
continuationThread = Environment.CurrentManagedThreadId;
evt.Set();
}
private async void DoAwait(Task task)
{
await task;
continuationThread = Environment.CurrentManagedThreadId;
evt.Set();
}
} }
#endif
} }
} }
...@@ -139,7 +139,7 @@ ...@@ -139,7 +139,7 @@
<Compile Include="StackExchange\Redis\SocketManager.cs" /> <Compile Include="StackExchange\Redis\SocketManager.cs" />
<Compile Include="StackExchange\Redis\SortType.cs" /> <Compile Include="StackExchange\Redis\SortType.cs" />
<Compile Include="StackExchange\Redis\StringSplits.cs" /> <Compile Include="StackExchange\Redis\StringSplits.cs" />
<Compile Include="StackExchange\Redis\TaskContinuationCheck.cs" /> <Compile Include="StackExchange\Redis\TaskSource.cs" />
<Compile Include="StackExchange\Redis\When.cs" /> <Compile Include="StackExchange\Redis\When.cs" />
<Compile Include="StackExchange\Redis\ShutdownMode.cs" /> <Compile Include="StackExchange\Redis\ShutdownMode.cs" />
<Compile Include="StackExchange\Redis\SaveType.cs" /> <Compile Include="StackExchange\Redis\SaveType.cs" />
......
...@@ -12,7 +12,9 @@ public static Task<T> Default(object asyncState) ...@@ -12,7 +12,9 @@ public static Task<T> Default(object asyncState)
} }
public static Task<T> FromResult(T value, object asyncState) public static Task<T> FromResult(T value, object asyncState)
{ {
var tcs = new TaskCompletionSource<T>(asyncState); // note we do not need to deny exec-sync here; the value will be known
// before we hand it to them
var tcs = TaskSource.Create<T>(asyncState);
tcs.SetResult(value); tcs.SetResult(value);
return tcs.Task; return tcs.Task;
} }
......
...@@ -17,17 +17,15 @@ sealed partial class CompletionManager ...@@ -17,17 +17,15 @@ sealed partial class CompletionManager
private readonly string name; private readonly string name;
long completedSync, completedAsync, failedAsync; long completedSync, completedAsync, failedAsync;
private readonly bool allowSyncContinuations;
public CompletionManager(ConnectionMultiplexer multiplexer, string name) public CompletionManager(ConnectionMultiplexer multiplexer, string name)
{ {
this.multiplexer = multiplexer; this.multiplexer = multiplexer;
this.name = name; this.name = name;
this.allowSyncContinuations = multiplexer.RawConfig.AllowSynchronousContinuations;
} }
public void CompleteSyncOrAsync(ICompletable operation) public void CompleteSyncOrAsync(ICompletable operation)
{ {
if (operation == null) return; if (operation == null) return;
if (operation.TryComplete(false, allowSyncContinuations)) if (operation.TryComplete(false))
{ {
multiplexer.Trace("Completed synchronously: " + operation, name); multiplexer.Trace("Completed synchronously: " + operation, name);
Interlocked.Increment(ref completedSync); Interlocked.Increment(ref completedSync);
...@@ -98,7 +96,7 @@ private static void AnyOrderCompletionHandler(object state) ...@@ -98,7 +96,7 @@ private static void AnyOrderCompletionHandler(object state)
try try
{ {
ConnectionMultiplexer.TraceWithoutContext("Completing async (any order): " + state); ConnectionMultiplexer.TraceWithoutContext("Completing async (any order): " + state);
((ICompletable)state).TryComplete(true, true); ((ICompletable)state).TryComplete(true);
} }
catch (Exception ex) catch (Exception ex)
{ {
...@@ -135,7 +133,7 @@ private void ProcessAsyncCompletionQueueImpl() ...@@ -135,7 +133,7 @@ private void ProcessAsyncCompletionQueueImpl()
try try
{ {
multiplexer.Trace("Completing async (ordered): " + next, name); multiplexer.Trace("Completing async (ordered): " + next, name);
next.TryComplete(true, allowSyncContinuations); next.TryComplete(true);
Interlocked.Increment(ref completedAsync); Interlocked.Increment(ref completedAsync);
} }
catch(Exception ex) catch(Exception ex)
......
...@@ -20,7 +20,7 @@ public sealed class ConfigurationOptions : ICloneable ...@@ -20,7 +20,7 @@ public sealed class ConfigurationOptions : ICloneable
VersionPrefix = "version=", ConnectTimeoutPrefix = "connectTimeout=", PasswordPrefix = "password=", VersionPrefix = "version=", ConnectTimeoutPrefix = "connectTimeout=", PasswordPrefix = "password=",
TieBreakerPrefix = "tiebreaker=", WriteBufferPrefix = "writeBuffer=", SslHostPrefix = "sslHost=", TieBreakerPrefix = "tiebreaker=", WriteBufferPrefix = "writeBuffer=", SslHostPrefix = "sslHost=",
ConfigChannelPrefix = "configChannel=", AbortOnConnectFailPrefix = "abortConnect=", ResolveDnsPrefix = "resolveDns=", ConfigChannelPrefix = "configChannel=", AbortOnConnectFailPrefix = "abortConnect=", ResolveDnsPrefix = "resolveDns=",
ChannelPrefixPrefix = "channelPrefix=", AllowSyncContinuationsPrefix = "syncCont="; ChannelPrefixPrefix = "channelPrefix=";
private readonly EndPointCollection endpoints = new EndPointCollection(); private readonly EndPointCollection endpoints = new EndPointCollection();
...@@ -29,7 +29,7 @@ public sealed class ConfigurationOptions : ICloneable ...@@ -29,7 +29,7 @@ public sealed class ConfigurationOptions : ICloneable
/// </summary> /// </summary>
public RedisChannel ChannelPrefix { get;set; } public RedisChannel ChannelPrefix { get;set; }
private bool? allowAdmin, abortOnConnectFail, resolveDns, allowSyncContinuations; private bool? allowAdmin, abortOnConnectFail, resolveDns;
private string clientName, serviceName, password, tieBreaker, sslHost, configChannel; private string clientName, serviceName, password, tieBreaker, sslHost, configChannel;
private Version defaultVersion; private Version defaultVersion;
...@@ -149,13 +149,6 @@ public ConfigurationOptions() ...@@ -149,13 +149,6 @@ public ConfigurationOptions()
/// </summary> /// </summary>
public bool AbortOnConnectFail { get { return abortOnConnectFail ?? true; } set { abortOnConnectFail = value; } } public bool AbortOnConnectFail { get { return abortOnConnectFail ?? true; } set { abortOnConnectFail = value; } }
/// <summary>
/// Gets or sets whether synchronous task continuations should be explicitly avoided (allowed by default)
/// </summary>
public bool AllowSynchronousContinuations { get { return allowSyncContinuations ?? true; } set { allowSyncContinuations = value; } }
/// <summary> /// <summary>
/// Parse the configuration from a comma-delimited configuration string /// Parse the configuration from a comma-delimited configuration string
/// </summary> /// </summary>
...@@ -178,7 +171,6 @@ public ConfigurationOptions Clone() ...@@ -178,7 +171,6 @@ public ConfigurationOptions Clone()
keepAlive = keepAlive, keepAlive = keepAlive,
syncTimeout = syncTimeout, syncTimeout = syncTimeout,
allowAdmin = allowAdmin, allowAdmin = allowAdmin,
allowSyncContinuations = allowSyncContinuations,
defaultVersion = defaultVersion, defaultVersion = defaultVersion,
connectTimeout = connectTimeout, connectTimeout = connectTimeout,
password = password, password = password,
...@@ -225,7 +217,6 @@ public override string ToString() ...@@ -225,7 +217,6 @@ public override string ToString()
Append(sb, AbortOnConnectFailPrefix, abortOnConnectFail); Append(sb, AbortOnConnectFailPrefix, abortOnConnectFail);
Append(sb, ResolveDnsPrefix, resolveDns); Append(sb, ResolveDnsPrefix, resolveDns);
Append(sb, ChannelPrefixPrefix, (string)ChannelPrefix); Append(sb, ChannelPrefixPrefix, (string)ChannelPrefix);
Append(sb, AllowSyncContinuationsPrefix, allowSyncContinuations);
CommandMap.AppendDeltas(sb); CommandMap.AppendDeltas(sb);
return sb.ToString(); return sb.ToString();
} }
...@@ -307,7 +298,7 @@ void Clear() ...@@ -307,7 +298,7 @@ void Clear()
{ {
clientName = serviceName = password = tieBreaker = sslHost = configChannel = null; clientName = serviceName = password = tieBreaker = sslHost = configChannel = null;
keepAlive = syncTimeout = connectTimeout = writeBuffer = null; keepAlive = syncTimeout = connectTimeout = writeBuffer = null;
allowAdmin = abortOnConnectFail = resolveDns = allowSyncContinuations = null; allowAdmin = abortOnConnectFail = resolveDns = null;
defaultVersion = null; defaultVersion = null;
endpoints.Clear(); endpoints.Clear();
CertificateSelection = null; CertificateSelection = null;
...@@ -358,11 +349,6 @@ private void DoParse(string configuration) ...@@ -358,11 +349,6 @@ private void DoParse(string configuration)
bool tmp; bool tmp;
if (Format.TryParseBoolean(value.Trim(), out tmp)) ResolveDns = tmp; if (Format.TryParseBoolean(value.Trim(), out tmp)) ResolveDns = tmp;
} }
else if (IsOption(option, AllowSyncContinuationsPrefix))
{
bool tmp;
if (Format.TryParseBoolean(value.Trim(), out tmp)) AllowSynchronousContinuations = tmp;
}
else if (IsOption(option, ServiceNamePrefix)) else if (IsOption(option, ServiceNamePrefix))
{ {
ServiceName = value.Trim(); ServiceName = value.Trim();
......
...@@ -56,7 +56,7 @@ public ConnectionFailureType FailureType ...@@ -56,7 +56,7 @@ public ConnectionFailureType FailureType
{ {
get { return failureType; } get { return failureType; }
} }
bool ICompletable.TryComplete(bool isAsync, bool allowSyncContinuations) bool ICompletable.TryComplete(bool isAsync)
{ {
return ConnectionMultiplexer.TryCompleteHandler(handler, sender, this, isAsync); return ConnectionMultiplexer.TryCompleteHandler(handler, sender, this, isAsync);
} }
......
...@@ -1593,7 +1593,7 @@ internal Task<T> ExecuteAsyncImpl<T>(Message message, ResultProcessor<T> process ...@@ -1593,7 +1593,7 @@ internal Task<T> ExecuteAsyncImpl<T>(Message message, ResultProcessor<T> process
} }
else else
{ {
var tcs = new TaskCompletionSource<T>(state); var tcs = TaskSource.CreateDenyExecSync<T>(state);
var source = ResultBox<T>.Get(tcs); var source = ResultBox<T>.Get(tcs);
if (!TryPushMessageToBridge(message, processor, source, ref server)) if (!TryPushMessageToBridge(message, processor, source, ref server))
{ {
......
...@@ -26,7 +26,7 @@ public EndPoint EndPoint ...@@ -26,7 +26,7 @@ public EndPoint EndPoint
{ {
get { return endpoint; } get { return endpoint; }
} }
bool ICompletable.TryComplete(bool isAsync, bool allowSyncContinuations) bool ICompletable.TryComplete(bool isAsync)
{ {
return ConnectionMultiplexer.TryCompleteHandler(handler, sender, this, isAsync); return ConnectionMultiplexer.TryCompleteHandler(handler, sender, this, isAsync);
} }
......
...@@ -36,7 +36,7 @@ public sealed class HashSlotMovedEventArgs : EventArgs, ICompletable ...@@ -36,7 +36,7 @@ public sealed class HashSlotMovedEventArgs : EventArgs, ICompletable
this.@new = @new; this.@new = @new;
} }
bool ICompletable.TryComplete(bool isAsync, bool allowSyncContinuations) bool ICompletable.TryComplete(bool isAsync)
{ {
return ConnectionMultiplexer.TryCompleteHandler(handler, sender, this, isAsync); return ConnectionMultiplexer.TryCompleteHandler(handler, sender, this, isAsync);
} }
......
...@@ -4,7 +4,7 @@ namespace StackExchange.Redis ...@@ -4,7 +4,7 @@ namespace StackExchange.Redis
{ {
interface ICompletable interface ICompletable
{ {
bool TryComplete(bool isAsync, bool allowSyncContinuations); bool TryComplete(bool isAsync);
void AppendStormLog(StringBuilder sb); void AppendStormLog(StringBuilder sb);
} }
} }
...@@ -56,7 +56,7 @@ public string Origin ...@@ -56,7 +56,7 @@ public string Origin
{ {
get { return origin; } get { return origin; }
} }
bool ICompletable.TryComplete(bool isAsync, bool allowSyncContinuations) bool ICompletable.TryComplete(bool isAsync)
{ {
return ConnectionMultiplexer.TryCompleteHandler(handler, sender, this, isAsync); return ConnectionMultiplexer.TryCompleteHandler(handler, sender, this, isAsync);
} }
......
...@@ -370,11 +370,11 @@ public override string ToString() ...@@ -370,11 +370,11 @@ public override string ToString()
resultProcessor == null ? "(n/a)" : resultProcessor.GetType().Name); resultProcessor == null ? "(n/a)" : resultProcessor.GetType().Name);
} }
public bool TryComplete(bool isAsync, bool allowSyncContinuations) public bool TryComplete(bool isAsync)
{ {
if (resultBox != null) if (resultBox != null)
{ {
return resultBox.TryComplete(isAsync, allowSyncContinuations); return resultBox.TryComplete(isAsync);
} }
else else
{ {
......
...@@ -22,7 +22,7 @@ public override string ToString() ...@@ -22,7 +22,7 @@ public override string ToString()
{ {
return (string)channel; return (string)channel;
} }
public bool TryComplete(bool isAsync, bool allowSyncContinuations) public bool TryComplete(bool isAsync)
{ {
if (handler == null) return true; if (handler == null) return true;
if (isAsync) if (isAsync)
......
...@@ -78,7 +78,7 @@ internal override Task<T> ExecuteAsync<T>(Message message, ResultProcessor<T> pr ...@@ -78,7 +78,7 @@ internal override Task<T> ExecuteAsync<T>(Message message, ResultProcessor<T> pr
} }
else else
{ {
var tcs = new TaskCompletionSource<T>(asyncState); var tcs = TaskSource.CreateDenyExecSync<T>(asyncState);
var source = ResultBox<T>.Get(tcs); var source = ResultBox<T>.Get(tcs);
message.SetSource(source, processor); message.SetSource(source, processor);
task = tcs.Task; task = tcs.Task;
......
...@@ -38,7 +38,7 @@ void ICompletable.AppendStormLog(StringBuilder sb) ...@@ -38,7 +38,7 @@ void ICompletable.AppendStormLog(StringBuilder sb)
sb.Append("event, error: ").Append(message); sb.Append("event, error: ").Append(message);
} }
bool ICompletable.TryComplete(bool isAsync, bool allowSyncContinuations) bool ICompletable.TryComplete(bool isAsync)
{ {
return ConnectionMultiplexer.TryCompleteHandler(handler, sender, this, isAsync); return ConnectionMultiplexer.TryCompleteHandler(handler, sender, this, isAsync);
} }
......
...@@ -470,7 +470,8 @@ internal override Task<T> ExecuteAsync<T>(Message message, ResultProcessor<T> pr ...@@ -470,7 +470,8 @@ internal override Task<T> ExecuteAsync<T>(Message message, ResultProcessor<T> pr
if (message == null) return CompletedTask<T>.Default(asyncState); if (message == null) return CompletedTask<T>.Default(asyncState);
if (message.IsFireAndForget) return CompletedTask<T>.Default(null); // F+F explicitly does not get async-state if (message.IsFireAndForget) return CompletedTask<T>.Default(null); // F+F explicitly does not get async-state
var tcs = new TaskCompletionSource<T>(asyncState); // no need to deny exec-sync here; will be complete before they see if
var tcs = TaskSource.Create<T>(asyncState);
ConnectionMultiplexer.ThrowFailed(tcs, ExceptionFactory.NoConnectionAvailable(message.Command)); ConnectionMultiplexer.ThrowFailed(tcs, ExceptionFactory.NoConnectionAvailable(message.Command));
return tcs.Task; return tcs.Task;
} }
......
...@@ -72,7 +72,7 @@ internal override Task<T> ExecuteAsync<T>(Message message, ResultProcessor<T> pr ...@@ -72,7 +72,7 @@ internal override Task<T> ExecuteAsync<T>(Message message, ResultProcessor<T> pr
} }
else else
{ {
var tcs = new TaskCompletionSource<T>(asyncState); var tcs = TaskSource.CreateDenyExecSync<T>(asyncState);
var source = ResultBox<T>.Get(tcs); var source = ResultBox<T>.Get(tcs);
message.SetSource(source, processor); message.SetSource(source, processor);
task = tcs.Task; task = tcs.Task;
......
...@@ -21,7 +21,7 @@ public void SetException(Exception exception) ...@@ -21,7 +21,7 @@ public void SetException(Exception exception)
// this.exception = caught; // this.exception = caught;
//} //}
} }
public abstract bool TryComplete(bool isAsync, bool allowSyncContinuations); public abstract bool TryComplete(bool isAsync);
[Conditional("DEBUG")] [Conditional("DEBUG")]
protected static void IncrementAllocationCount() protected static void IncrementAllocationCount()
...@@ -94,12 +94,12 @@ public void SetResult(T value) ...@@ -94,12 +94,12 @@ public void SetResult(T value)
this.value = value; this.value = value;
} }
public override bool TryComplete(bool isAsync, bool allowSyncContinuations) public override bool TryComplete(bool isAsync)
{ {
if (stateOrCompletionSource is TaskCompletionSource<T>) if (stateOrCompletionSource is TaskCompletionSource<T>)
{ {
var tcs = (TaskCompletionSource<T>)stateOrCompletionSource; var tcs = (TaskCompletionSource<T>)stateOrCompletionSource;
if (isAsync || allowSyncContinuations || TaskContinationCheck.NoContinuations(tcs.Task)) if (isAsync || TaskSource.IsSyncSafe(tcs.Task))
{ {
T val; T val;
Exception ex; Exception ex;
......
...@@ -336,7 +336,7 @@ internal void OnHeartbeat() ...@@ -336,7 +336,7 @@ internal void OnHeartbeat()
internal Task<T> QueueDirectAsync<T>(Message message, ResultProcessor<T> processor, object asyncState = null, PhysicalBridge bridge = null) internal Task<T> QueueDirectAsync<T>(Message message, ResultProcessor<T> processor, object asyncState = null, PhysicalBridge bridge = null)
{ {
var tcs = new TaskCompletionSource<T>(asyncState); var tcs = TaskSource.CreateDenyExecSync<T>(asyncState);
var source = ResultBox<T>.Get(tcs); var source = ResultBox<T>.Get(tcs);
message.SetSource(processor, source); message.SetSource(processor, source);
if(!(bridge ?? GetBridge(message.Command)).TryEnqueue(message, isSlave)) if(!(bridge ?? GetBridge(message.Command)).TryEnqueue(message, isSlave))
......
using System;
using System.Reflection;
using System.Reflection.Emit;
using System.Threading;
using System.Threading.Tasks;
namespace StackExchange.Redis
{
/// <summary>
/// Utility to detect continuations on tasks
/// </summary>
public static class TaskContinationCheck
{
static TaskContinationCheck()
{
NoContinuations = task => false; // assume the worst, hope for the best
try
{
var field = typeof(Task).GetField("m_continuationObject", BindingFlags.NonPublic | BindingFlags.Instance);
if (field == null)
{
System.Diagnostics.Trace.WriteLine("Expected field not found: Task.m_continuationObject");
return;
}
var method = new DynamicMethod("NoContinuations", typeof(bool), new[] { typeof(Task) },
typeof(Task), true);
var il = method.GetILGenerator();
il.Emit(OpCodes.Ldarg_0);
il.Emit(OpCodes.Ldflda, field);
il.Emit(OpCodes.Ldnull);
il.Emit(OpCodes.Ldnull);
il.EmitCall(OpCodes.Call, typeof(Interlocked).GetMethod("CompareExchange", new[] { typeof(object).MakeByRefType(), typeof(object), typeof(object) }), null);
il.Emit(OpCodes.Ldnull);
il.Emit(OpCodes.Ceq);
il.Emit(OpCodes.Ret);
var func = (Func<Task, bool>)method.CreateDelegate(typeof(Func<Task, bool>));
TaskCompletionSource<int> source = new TaskCompletionSource<int>();
var before = func(source.Task);
source.Task.ContinueWith(t => { });
var after = func(source.Task);
if (!before)
{
System.Diagnostics.Trace.WriteLine("vanilla task should report true");
return;
}
if (after)
{
System.Diagnostics.Trace.WriteLine("task with continuation should report false");
return;
}
source.TrySetResult(0);
NoContinuations = func;
}
catch (Exception ex)
{
System.Diagnostics.Trace.WriteLine(ex.Message);
}
}
/// <summary>
/// Does the specified task have no continuations?
/// </summary>
public static readonly Func<Task, bool> NoContinuations;
}
}
using System;
using System.Diagnostics;
using System.Reflection;
using System.Reflection.Emit;
using System.Threading.Tasks;
namespace StackExchange.Redis
{
/// <summary>
/// We want to prevent callers hijacking the reader thread; this is a bit nasty, but works;
/// see http://stackoverflow.com/a/22588431/23354 for more information; a huge
/// thanks to Eli Arbel for spotting this (evin if it is pure evil)
/// </summary>
#if DEBUG
public // for the unit tests in TaskTests.cs
#endif
static class TaskSource
{
/// <summary>
/// Indicates whether the specified task will not hijack threads when results are set
/// </summary>
public static readonly Func<Task, bool> IsSyncSafe;
static Action<Task> denyExecSync;
static TaskSource()
{
try
{
var stateField = typeof(Task).GetField("m_stateFlags", BindingFlags.Instance | BindingFlags.NonPublic);
if (stateField != null)
{
var constField = typeof(Task).GetField("TASK_STATE_THREAD_WAS_ABORTED", BindingFlags.Static | BindingFlags.NonPublic);
// try to use the literal field value, but settle for hard-coded if it isn't there
int flag = constField == null ? 134217728 : (int)constField.GetRawConstantValue();
var method = new DynamicMethod("DenyExecSync", null, new[] { typeof(Task) }, typeof(Task), true);
var il = method.GetILGenerator();
il.Emit(OpCodes.Ldarg_0); // [task]
il.Emit(OpCodes.Ldarg_0); // [task, task]
il.Emit(OpCodes.Ldfld, stateField); // [task, flags]
il.Emit(OpCodes.Ldc_I4, flag); // [task, flags, 134217728]
il.Emit(OpCodes.Or); // [task, combined]
il.Emit(OpCodes.Stfld, stateField); // []
il.Emit(OpCodes.Ret);
denyExecSync = (Action<Task>)method.CreateDelegate(typeof(Action<Task>));
method = new DynamicMethod("IsSyncSafe", typeof(bool), new[] { typeof(Task) }, typeof(Task), true);
il = method.GetILGenerator();
il.Emit(OpCodes.Ldc_I4, flag); // [134217728]
il.Emit(OpCodes.Ldarg_0); // [134217728, task]
il.Emit(OpCodes.Ldfld, stateField); // [134217728, flags]
il.Emit(OpCodes.Ldc_I4, flag); // [134217728, flags, 134217728]
il.Emit(OpCodes.And); // [134217728, single-flag]
il.Emit(OpCodes.Ceq); // [true/false]
il.Emit(OpCodes.Ret);
IsSyncSafe = (Func<Task, bool>)method.CreateDelegate(typeof(Func<Task, bool>));
}
}
catch(Exception ex)
{
Debug.WriteLine(ex.Message);
Trace.WriteLine(ex.Message);
}
if(denyExecSync == null)
denyExecSync = t => { }; // no-op if that fails
if (IsSyncSafe == null)
IsSyncSafe = t => false; // assume: not
}
/// <summary>
/// Create a new TaskCompletionSource that will not allow result-setting threads to be hijacked
/// </summary>
public static TaskCompletionSource<T> CreateDenyExecSync<T>(object asyncState)
{
var source = new TaskCompletionSource<T>(asyncState);
denyExecSync(source.Task);
return source;
}
/// <summary>
/// Create a new TaskCompletion source
/// </summary>
public static TaskCompletionSource<T> Create<T>(object asyncState)
{
return new TaskCompletionSource<T>(asyncState);
}
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment