Commit 50547eba authored by Marc Gravell's avatar Marc Gravell

Special case SetOnInvokeMres (known good) rather than spoofing task abort

parent 4a02a425
using NUnit.Framework;
using System;
using System.Threading.Tasks;
namespace StackExchange.Redis.Tests.Issues
{
[TestFixture]
public class SO25567566 : TestBase
{
protected override string GetConfiguration()
{
return "127.0.0.1";
}
[Test]
public async void Execute()
{
using(var conn = ConnectionMultiplexer.Connect(GetConfiguration())) // Create())
{
for(int i = 0; i < 100; i++)
{
Assert.AreEqual("ok", await DoStuff(conn));
}
}
}
private async Task<string> DoStuff(ConnectionMultiplexer conn)
{
var db = conn.GetDatabase();
var timeout = Task.Delay(5000);
var len = db.ListLengthAsync("list");
if (await Task.WhenAny(timeout, len) != len)
{
return "Timeout getting length";
}
if ((await len) == 0)
{
db.ListRightPush("list", "foo", flags: CommandFlags.FireAndForget);
}
var tran = db.CreateTransaction();
var x = tran.ListRightPopLeftPushAsync("list", "list2");
var y = tran.SetAddAsync("set", "bar");
var z = tran.KeyExpireAsync("list2", TimeSpan.FromSeconds(60));
timeout = Task.Delay(5000);
var exec = tran.ExecuteAsync();
// SWAP THESE TWO
bool ok = await Task.WhenAny(exec, timeout) == exec;
//bool ok = true;
if (ok)
{
if (await exec)
{
await Task.WhenAll(x, y, z);
var db2 = conn.GetDatabase();
db2.HashGet("hash", "whatever");
return "ok";
}
else
{
return "Transaction aborted";
}
}
else
{
return "Timeout during exec";
}
}
}
}
......@@ -79,6 +79,7 @@
<Compile Include="Issues\SO22786599.cs" />
<Compile Include="Issues\SO23949477.cs" />
<Compile Include="Issues\SO24807536.cs" />
<Compile Include="Issues\SO25567566.cs" />
<Compile Include="Keys.cs" />
<Compile Include="KeysAndValues.cs" />
<Compile Include="Lex.cs" />
......
......@@ -20,49 +20,50 @@ static class TaskSource
/// Indicates whether the specified task will not hijack threads when results are set
/// </summary>
public static readonly Func<Task, bool> IsSyncSafe;
private static readonly Action<Task> DenyExecSync;
static TaskSource()
{
try
{
var stateField = typeof(Task).GetField("m_stateFlags", BindingFlags.Instance | BindingFlags.NonPublic);
if (stateField != null)
Type taskType = typeof(Task);
FieldInfo continuationField = taskType.GetField("m_continuationObject", BindingFlags.Instance | BindingFlags.NonPublic);
Type safeScenario = taskType.GetNestedType("SetOnInvokeMres", BindingFlags.NonPublic);
if (continuationField != null && continuationField.FieldType == typeof(object) && safeScenario != null)
{
var constField = typeof(Task).GetField("TASK_STATE_THREAD_WAS_ABORTED", BindingFlags.Static | BindingFlags.NonPublic);
// try to use the literal field value, but settle for hard-coded if it isn't there
int flag = constField == null ? 134217728 : (int)constField.GetRawConstantValue();
var method = new DynamicMethod("DenyExecSync", null, new[] { typeof(Task) }, typeof(Task), true);
var method = new DynamicMethod("IsSyncSafe", typeof(bool), new[] { typeof(Task) }, typeof(Task), true);
var il = method.GetILGenerator();
il.Emit(OpCodes.Ldarg_0); // [task]
il.Emit(OpCodes.Ldarg_0); // [task, task]
il.Emit(OpCodes.Ldfld, stateField); // [task, flags]
il.Emit(OpCodes.Ldc_I4, flag); // [task, flags, 134217728]
il.Emit(OpCodes.Or); // [task, combined]
il.Emit(OpCodes.Stfld, stateField); // []
var hasContinuation = il.DefineLabel();
il.Emit(OpCodes.Ldarg_0);
il.Emit(OpCodes.Ldfld, continuationField);
Label nonNull = il.DefineLabel(), goodReturn = il.DefineLabel();
// check if null
il.Emit(OpCodes.Brtrue_S, nonNull);
il.MarkLabel(goodReturn);
il.Emit(OpCodes.Ldc_I4_1);
il.Emit(OpCodes.Ret);
DenyExecSync = (Action<Task>)method.CreateDelegate(typeof(Action<Task>));
method = new DynamicMethod("IsSyncSafe", typeof(bool), new[] { typeof(Task) }, typeof(Task), true);
il = method.GetILGenerator();
il.Emit(OpCodes.Ldc_I4, flag); // [134217728]
il.Emit(OpCodes.Ldarg_0); // [134217728, task]
il.Emit(OpCodes.Ldfld, stateField); // [134217728, flags]
il.Emit(OpCodes.Ldc_I4, flag); // [134217728, flags, 134217728]
il.Emit(OpCodes.And); // [134217728, single-flag]
il.Emit(OpCodes.Ceq); // [true/false]
// check if is a SetOnInvokeMres - if so, we're OK
il.MarkLabel(nonNull);
il.Emit(OpCodes.Ldarg_0);
il.Emit(OpCodes.Ldfld, continuationField);
il.Emit(OpCodes.Isinst, safeScenario);
il.Emit(OpCodes.Brtrue_S, goodReturn);
il.Emit(OpCodes.Ldc_I4_0);
il.Emit(OpCodes.Ret);
IsSyncSafe = (Func<Task, bool>)method.CreateDelegate(typeof(Func<Task, bool>));
// and test them (check for an exception etc)
var tcs = new TaskCompletionSource<int>();
DenyExecSync(tcs.Task);
if(!IsSyncSafe(tcs.Task))
bool expectTrue = IsSyncSafe(tcs.Task);
tcs.Task.ContinueWith(delegate { });
bool expectFalse = IsSyncSafe(tcs.Task);
tcs.SetResult(0);
if(!expectTrue || expectFalse)
{
Debug.WriteLine("IsSyncSafe reported false!");
Trace.WriteLine("IsSyncSafe reported false!");
// revert to not trusting them
DenyExecSync = null;
Debug.WriteLine("IsSyncSafe reported incorrectly!");
Trace.WriteLine("IsSyncSafe reported incorrectly!");
// revert to not trusting /them
IsSyncSafe = null;
}
}
......@@ -71,12 +72,8 @@ static TaskSource()
{
Debug.WriteLine(ex.Message);
Trace.WriteLine(ex.Message);
DenyExecSync = null;
IsSyncSafe = null;
}
if(DenyExecSync == null)
DenyExecSync = t => { }; // no-op if that fails
if (IsSyncSafe == null)
IsSyncSafe = t => false; // assume: not
}
......@@ -95,7 +92,7 @@ public static TaskCompletionSource<T> Create<T>(object asyncState)
public static TaskCompletionSource<T> CreateDenyExecSync<T>(object asyncState)
{
var source = new TaskCompletionSource<T>(asyncState);
DenyExecSync(source.Task);
//DenyExecSync(source.Task);
return source;
}
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment