Commit 50547eba authored by Marc Gravell's avatar Marc Gravell

Special case SetOnInvokeMres (known good) rather than spoofing task abort

parent 4a02a425
using NUnit.Framework;
using System;
using System.Threading.Tasks;
namespace StackExchange.Redis.Tests.Issues
{
[TestFixture]
public class SO25567566 : TestBase
{
protected override string GetConfiguration()
{
return "127.0.0.1";
}
[Test]
public async void Execute()
{
using(var conn = ConnectionMultiplexer.Connect(GetConfiguration())) // Create())
{
for(int i = 0; i < 100; i++)
{
Assert.AreEqual("ok", await DoStuff(conn));
}
}
}
private async Task<string> DoStuff(ConnectionMultiplexer conn)
{
var db = conn.GetDatabase();
var timeout = Task.Delay(5000);
var len = db.ListLengthAsync("list");
if (await Task.WhenAny(timeout, len) != len)
{
return "Timeout getting length";
}
if ((await len) == 0)
{
db.ListRightPush("list", "foo", flags: CommandFlags.FireAndForget);
}
var tran = db.CreateTransaction();
var x = tran.ListRightPopLeftPushAsync("list", "list2");
var y = tran.SetAddAsync("set", "bar");
var z = tran.KeyExpireAsync("list2", TimeSpan.FromSeconds(60));
timeout = Task.Delay(5000);
var exec = tran.ExecuteAsync();
// SWAP THESE TWO
bool ok = await Task.WhenAny(exec, timeout) == exec;
//bool ok = true;
if (ok)
{
if (await exec)
{
await Task.WhenAll(x, y, z);
var db2 = conn.GetDatabase();
db2.HashGet("hash", "whatever");
return "ok";
}
else
{
return "Transaction aborted";
}
}
else
{
return "Timeout during exec";
}
}
}
}
...@@ -79,6 +79,7 @@ ...@@ -79,6 +79,7 @@
<Compile Include="Issues\SO22786599.cs" /> <Compile Include="Issues\SO22786599.cs" />
<Compile Include="Issues\SO23949477.cs" /> <Compile Include="Issues\SO23949477.cs" />
<Compile Include="Issues\SO24807536.cs" /> <Compile Include="Issues\SO24807536.cs" />
<Compile Include="Issues\SO25567566.cs" />
<Compile Include="Keys.cs" /> <Compile Include="Keys.cs" />
<Compile Include="KeysAndValues.cs" /> <Compile Include="KeysAndValues.cs" />
<Compile Include="Lex.cs" /> <Compile Include="Lex.cs" />
......
...@@ -20,49 +20,50 @@ static class TaskSource ...@@ -20,49 +20,50 @@ static class TaskSource
/// Indicates whether the specified task will not hijack threads when results are set /// Indicates whether the specified task will not hijack threads when results are set
/// </summary> /// </summary>
public static readonly Func<Task, bool> IsSyncSafe; public static readonly Func<Task, bool> IsSyncSafe;
private static readonly Action<Task> DenyExecSync;
static TaskSource() static TaskSource()
{ {
try try
{ {
var stateField = typeof(Task).GetField("m_stateFlags", BindingFlags.Instance | BindingFlags.NonPublic); Type taskType = typeof(Task);
if (stateField != null) FieldInfo continuationField = taskType.GetField("m_continuationObject", BindingFlags.Instance | BindingFlags.NonPublic);
Type safeScenario = taskType.GetNestedType("SetOnInvokeMres", BindingFlags.NonPublic);
if (continuationField != null && continuationField.FieldType == typeof(object) && safeScenario != null)
{ {
var constField = typeof(Task).GetField("TASK_STATE_THREAD_WAS_ABORTED", BindingFlags.Static | BindingFlags.NonPublic); var method = new DynamicMethod("IsSyncSafe", typeof(bool), new[] { typeof(Task) }, typeof(Task), true);
// try to use the literal field value, but settle for hard-coded if it isn't there
int flag = constField == null ? 134217728 : (int)constField.GetRawConstantValue();
var method = new DynamicMethod("DenyExecSync", null, new[] { typeof(Task) }, typeof(Task), true);
var il = method.GetILGenerator(); var il = method.GetILGenerator();
il.Emit(OpCodes.Ldarg_0); // [task] var hasContinuation = il.DefineLabel();
il.Emit(OpCodes.Ldarg_0); // [task, task] il.Emit(OpCodes.Ldarg_0);
il.Emit(OpCodes.Ldfld, stateField); // [task, flags] il.Emit(OpCodes.Ldfld, continuationField);
il.Emit(OpCodes.Ldc_I4, flag); // [task, flags, 134217728] Label nonNull = il.DefineLabel(), goodReturn = il.DefineLabel();
il.Emit(OpCodes.Or); // [task, combined] // check if null
il.Emit(OpCodes.Stfld, stateField); // [] il.Emit(OpCodes.Brtrue_S, nonNull);
il.MarkLabel(goodReturn);
il.Emit(OpCodes.Ldc_I4_1);
il.Emit(OpCodes.Ret); il.Emit(OpCodes.Ret);
DenyExecSync = (Action<Task>)method.CreateDelegate(typeof(Action<Task>));
method = new DynamicMethod("IsSyncSafe", typeof(bool), new[] { typeof(Task) }, typeof(Task), true); // check if is a SetOnInvokeMres - if so, we're OK
il = method.GetILGenerator(); il.MarkLabel(nonNull);
il.Emit(OpCodes.Ldc_I4, flag); // [134217728] il.Emit(OpCodes.Ldarg_0);
il.Emit(OpCodes.Ldarg_0); // [134217728, task] il.Emit(OpCodes.Ldfld, continuationField);
il.Emit(OpCodes.Ldfld, stateField); // [134217728, flags] il.Emit(OpCodes.Isinst, safeScenario);
il.Emit(OpCodes.Ldc_I4, flag); // [134217728, flags, 134217728] il.Emit(OpCodes.Brtrue_S, goodReturn);
il.Emit(OpCodes.And); // [134217728, single-flag]
il.Emit(OpCodes.Ceq); // [true/false] il.Emit(OpCodes.Ldc_I4_0);
il.Emit(OpCodes.Ret); il.Emit(OpCodes.Ret);
IsSyncSafe = (Func<Task, bool>)method.CreateDelegate(typeof(Func<Task, bool>)); IsSyncSafe = (Func<Task, bool>)method.CreateDelegate(typeof(Func<Task, bool>));
// and test them (check for an exception etc) // and test them (check for an exception etc)
var tcs = new TaskCompletionSource<int>(); var tcs = new TaskCompletionSource<int>();
DenyExecSync(tcs.Task); bool expectTrue = IsSyncSafe(tcs.Task);
if(!IsSyncSafe(tcs.Task)) tcs.Task.ContinueWith(delegate { });
bool expectFalse = IsSyncSafe(tcs.Task);
tcs.SetResult(0);
if(!expectTrue || expectFalse)
{ {
Debug.WriteLine("IsSyncSafe reported false!"); Debug.WriteLine("IsSyncSafe reported incorrectly!");
Trace.WriteLine("IsSyncSafe reported false!"); Trace.WriteLine("IsSyncSafe reported incorrectly!");
// revert to not trusting them // revert to not trusting /them
DenyExecSync = null;
IsSyncSafe = null; IsSyncSafe = null;
} }
} }
...@@ -71,12 +72,8 @@ static TaskSource() ...@@ -71,12 +72,8 @@ static TaskSource()
{ {
Debug.WriteLine(ex.Message); Debug.WriteLine(ex.Message);
Trace.WriteLine(ex.Message); Trace.WriteLine(ex.Message);
DenyExecSync = null;
IsSyncSafe = null; IsSyncSafe = null;
} }
if(DenyExecSync == null)
DenyExecSync = t => { }; // no-op if that fails
if (IsSyncSafe == null) if (IsSyncSafe == null)
IsSyncSafe = t => false; // assume: not IsSyncSafe = t => false; // assume: not
} }
...@@ -95,7 +92,7 @@ public static TaskCompletionSource<T> Create<T>(object asyncState) ...@@ -95,7 +92,7 @@ public static TaskCompletionSource<T> Create<T>(object asyncState)
public static TaskCompletionSource<T> CreateDenyExecSync<T>(object asyncState) public static TaskCompletionSource<T> CreateDenyExecSync<T>(object asyncState)
{ {
var source = new TaskCompletionSource<T>(asyncState); var source = new TaskCompletionSource<T>(asyncState);
DenyExecSync(source.Task); //DenyExecSync(source.Task);
return source; return source;
} }
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment