Unverified Commit fa123558 authored by Marc Gravell's avatar Marc Gravell Committed by GitHub

Merge pull request #877 from StackExchange/remove-preserve-order

remove PreserveAsyncOrder
parents f2cf3a18 7a7647a5
......@@ -14,14 +14,11 @@ public class BasicOpsTests : TestBase
{
public BasicOpsTests(ITestOutputHelper output) : base (output) { }
[Theory]
[InlineData(true)]
[InlineData(false)]
public void PingOnce(bool preserveOrder)
[Fact]
public void PingOnce()
{
using (var muxer = Create())
{
muxer.PreserveAsyncOrder = preserveOrder;
var conn = muxer.GetDatabase();
var task = conn.PingAsync();
......@@ -51,14 +48,11 @@ public void RapidDispose()
}
}
[Theory]
[InlineData(true)]
[InlineData(false)]
public void PingMany(bool preserveOrder)
[Fact]
public void PingMany()
{
using (var muxer = Create())
{
muxer.PreserveAsyncOrder = preserveOrder;
var conn = muxer.GetDatabase();
var tasks = new Task<TimeSpan>[10000];
......@@ -155,14 +149,11 @@ public void SetWithZeroValue()
}
}
[Theory]
[InlineData(true)]
[InlineData(false)]
public async Task GetSetAsync(bool preserveOrder)
[Fact]
public async Task GetSetAsync()
{
using (var muxer = Create())
{
muxer.PreserveAsyncOrder = preserveOrder;
var conn = muxer.GetDatabase();
RedisKey key = Me();
......@@ -185,14 +176,11 @@ public async Task GetSetAsync(bool preserveOrder)
}
}
[Theory]
[InlineData(true)]
[InlineData(false)]
public void GetSetSync(bool preserveOrder)
[Fact]
public void GetSetSync()
{
using (var muxer = Create())
{
muxer.PreserveAsyncOrder = preserveOrder;
var conn = muxer.GetDatabase();
RedisKey key = Me();
......@@ -298,15 +286,12 @@ public void GetWithExpiryWrongTypeSync()
}
#if DEBUG
[Theory]
[InlineData(true)]
[InlineData(false)]
public void TestQuit(bool preserveOrder)
[Fact]
public void TestQuit()
{
SetExpectedAmbientFailureCount(1);
using (var muxer = Create(allowAdmin: true))
{
muxer.PreserveAsyncOrder = preserveOrder;
var db = muxer.GetDatabase();
string key = Guid.NewGuid().ToString();
db.KeyDelete(key, CommandFlags.FireAndForget);
......@@ -315,23 +300,19 @@ public void TestQuit(bool preserveOrder)
var watch = Stopwatch.StartNew();
Assert.Throws<RedisConnectionException>(() => db.Ping());
watch.Stop();
Output.WriteLine("Time to notice quit: {0}ms ({1})", watch.ElapsedMilliseconds,
preserveOrder ? "preserve order" : "any order");
Output.WriteLine("Time to notice quit: {0}ms (any order)", watch.ElapsedMilliseconds);
Thread.Sleep(20);
Debug.WriteLine("Pinging...");
Assert.Equal(key, (string)db.StringGet(key));
}
}
[Theory]
[InlineData(true)]
[InlineData(false)]
public async Task TestSevered(bool preserveOrder)
[Fact]
public async Task TestSevered()
{
SetExpectedAmbientFailureCount(2);
using (var muxer = Create(allowAdmin: true))
{
muxer.PreserveAsyncOrder = preserveOrder;
var db = muxer.GetDatabase();
string key = Guid.NewGuid().ToString();
db.KeyDelete(key, CommandFlags.FireAndForget);
......@@ -340,8 +321,7 @@ public async Task TestSevered(bool preserveOrder)
var watch = Stopwatch.StartNew();
db.Ping();
watch.Stop();
Output.WriteLine("Time to re-establish: {0}ms ({1})", watch.ElapsedMilliseconds,
preserveOrder ? "preserve order" : "any order");
Output.WriteLine("Time to re-establish: {0}ms (any order)", watch.ElapsedMilliseconds);
await Task.Delay(2000).ForAwait();
Debug.WriteLine("Pinging...");
Assert.Equal(key, db.StringGet(key));
......@@ -349,14 +329,11 @@ public async Task TestSevered(bool preserveOrder)
}
#endif
[Theory]
[InlineData(true)]
[InlineData(false)]
public async Task IncrAsync(bool preserveOrder)
[Fact]
public async Task IncrAsync()
{
using (var muxer = Create())
{
muxer.PreserveAsyncOrder = preserveOrder;
var conn = muxer.GetDatabase();
RedisKey key = Me();
conn.KeyDelete(key, CommandFlags.FireAndForget);
......@@ -382,14 +359,11 @@ public async Task IncrAsync(bool preserveOrder)
}
}
[Theory]
[InlineData(true)]
[InlineData(false)]
public void IncrSync(bool preserveOrder)
[Fact]
public void IncrSync()
{
using (var muxer = Create())
{
muxer.PreserveAsyncOrder = preserveOrder;
var conn = muxer.GetDatabase();
RedisKey key = Me();
conn.KeyDelete(key, CommandFlags.FireAndForget);
......
......@@ -3,6 +3,7 @@
using System.Diagnostics;
using System.Text;
using System.Threading;
using System.Threading.Channels;
using System.Threading.Tasks;
using Xunit;
using Xunit.Abstractions;
......@@ -76,17 +77,17 @@ private void TestMassivePublish(ISubscriber conn, string channel, string caption
}
[FactLongRunning]
public async Task PubSubOrder()
public async Task PubSubGetAllAnyOrder()
{
using (var muxer = GetRemoteConnection(waitForOpen: true))
using (var muxer = GetRemoteConnection(waitForOpen: true,
syncTimeout: 20000))
{
var sub = muxer.GetSubscriber();
const string channel = "PubSubOrder";
const int count = 500000;
RedisChannel channel = Me();
const int count = 1000;
var syncLock = new object();
var data = new List<int>(count);
muxer.PreserveAsyncOrder = true;
var data = new HashSet<int>();
await sub.SubscribeAsync(channel, (key, val) =>
{
bool pulse;
......@@ -94,7 +95,7 @@ public async Task PubSubOrder()
{
data.Add(int.Parse(Encoding.UTF8.GetString(val)));
pulse = data.Count == count;
if ((data.Count % 10) == 99) Output.WriteLine(data.Count.ToString());
if ((data.Count % 100) == 99) Output.WriteLine(data.Count.ToString());
}
if (pulse)
{
......@@ -117,10 +118,197 @@ public async Task PubSubOrder()
throw new TimeoutException("Items: " + data.Count);
}
for (int i = 0; i < count; i++)
{
Assert.Contains(i, data);
}
}
}
}
[Fact]
public async Task PubSubGetAllCorrectOrder()
{
using (var muxer = GetRemoteConnection(waitForOpen: true,
syncTimeout: 20000))
{
var sub = muxer.GetSubscriber();
RedisChannel channel = Me();
const int count = 1000;
var syncLock = new object();
var data = new List<int>(count);
var subChannel = await sub.SubscribeAsync(channel);
await sub.PingAsync();
async Task RunLoop()
{
while (!subChannel.IsCompleted)
{
var work = await subChannel.ReadAsync();
int i = int.Parse(Encoding.UTF8.GetString(work.Value));
lock (data)
{
data.Add(i);
if (data.Count == count) break;
if ((data.Count % 100) == 99) Output.WriteLine(data.Count.ToString());
}
}
lock (syncLock)
{
Monitor.PulseAll(syncLock);
}
}
lock (syncLock)
{
Task.Run(RunLoop);
for (int i = 0; i < count; i++)
{
sub.Publish(channel, i.ToString(), CommandFlags.FireAndForget);
}
if (!Monitor.Wait(syncLock, 20000))
{
throw new TimeoutException("Items: " + data.Count);
}
subChannel.Unsubscribe();
sub.Ping();
muxer.GetDatabase().Ping();
for (int i = 0; i < count; i++)
{
Assert.Equal(i, data[i]);
}
}
Assert.True(subChannel.IsCompleted);
await Assert.ThrowsAsync<ChannelClosedException>(async delegate
{
var final = await subChannel.ReadAsync();
});
}
}
[Fact]
public async Task PubSubGetAllCorrectOrder_OnMessage_Sync()
{
using (var muxer = GetRemoteConnection(waitForOpen: true,
syncTimeout: 20000))
{
var sub = muxer.GetSubscriber();
RedisChannel channel = Me();
const int count = 1000;
var syncLock = new object();
var data = new List<int>(count);
var subChannel = await sub.SubscribeAsync(channel);
subChannel.OnMessage((key, val) =>
{
int i = int.Parse(Encoding.UTF8.GetString(val));
bool pulse = false;
lock (data)
{
data.Add(i);
if (data.Count == count) pulse = true;
if ((data.Count % 100) == 99) Output.WriteLine(data.Count.ToString());
}
if (pulse)
{
lock (syncLock)
{
Monitor.PulseAll(syncLock);
}
}
});
await sub.PingAsync();
lock (syncLock)
{
for (int i = 0; i < count; i++)
{
sub.Publish(channel, i.ToString(), CommandFlags.FireAndForget);
}
if (!Monitor.Wait(syncLock, 20000))
{
throw new TimeoutException("Items: " + data.Count);
}
subChannel.Unsubscribe();
sub.Ping();
muxer.GetDatabase().Ping();
for (int i = 0; i < count; i++)
{
Assert.Equal(i, data[i]);
}
}
Assert.True(subChannel.IsCompleted);
await Assert.ThrowsAsync<ChannelClosedException>(async delegate
{
var final = await subChannel.ReadAsync();
});
}
}
[Fact]
public async Task PubSubGetAllCorrectOrder_OnMessage_Async()
{
using (var muxer = GetRemoteConnection(waitForOpen: true,
syncTimeout: 20000))
{
var sub = muxer.GetSubscriber();
RedisChannel channel = Me();
const int count = 1000;
var syncLock = new object();
var data = new List<int>(count);
var subChannel = await sub.SubscribeAsync(channel);
subChannel.OnMessage((key, val) =>
{
int i = int.Parse(Encoding.UTF8.GetString(val));
bool pulse = false;
lock (data)
{
data.Add(i);
if (data.Count == count) pulse = true;
if ((data.Count % 100) == 99) Output.WriteLine(data.Count.ToString());
}
if (pulse)
{
lock (syncLock)
{
Monitor.PulseAll(syncLock);
}
}
return i % 2 == 0 ? null : Task.CompletedTask;
});
await sub.PingAsync();
lock (syncLock)
{
for (int i = 0; i < count; i++)
{
sub.Publish(channel, i.ToString(), CommandFlags.FireAndForget);
}
if (!Monitor.Wait(syncLock, 20000))
{
throw new TimeoutException("Items: " + data.Count);
}
subChannel.Unsubscribe();
sub.Ping();
muxer.GetDatabase().Ping();
for (int i = 0; i < count; i++)
{
Assert.Equal(i, data[i]);
}
}
Assert.True(subChannel.IsCompleted);
await Assert.ThrowsAsync<ChannelClosedException>(async delegate
{
var final = await subChannel.ReadAsync();
});
}
}
......
......@@ -10,27 +10,26 @@ public class Issue791 : TestBase
[Fact]
public void PreserveAsyncOrderImplicitValue_ParsedFromConnectionString()
{
// We only care that it parses successfully while deprecated
var options = ConfigurationOptions.Parse("preserveAsyncOrder=true");
Assert.True(options.PreserveAsyncOrder);
Assert.Equal("preserveAsyncOrder=True", options.ToString());
Assert.Equal("", options.ToString());
// We only care that it parses successfully while deprecated
options = ConfigurationOptions.Parse("preserveAsyncOrder=false");
Assert.False(options.PreserveAsyncOrder);
Assert.Equal("preserveAsyncOrder=False", options.ToString());
Assert.Equal("", options.ToString());
}
[Fact]
public void DefaultValue_IsTrue()
{
var options = ConfigurationOptions.Parse("ssl=true");
Assert.True(options.PreserveAsyncOrder);
}
[Fact]
public void PreserveAsyncOrder_SetConnectionMultiplexerProperty()
{
// We only care that it parses successfully while deprecated
var multiplexer = ConnectionMultiplexer.Connect(TestConfig.Current.MasterServerAndPort + ",preserveAsyncOrder=false");
Assert.False(multiplexer.PreserveAsyncOrder);
}
}
}
......@@ -13,18 +13,12 @@ public class MassiveOps : TestBase
public MassiveOps(ITestOutputHelper output) : base(output) { }
[Theory]
[InlineData(true, true)]
[InlineData(true, false)]
[InlineData(false, true)]
[InlineData(false, false)]
public async Task MassiveBulkOpsAsync(bool preserveOrder, bool withContinuation)
[InlineData(true)]
[InlineData(false)]
public async Task MassiveBulkOpsAsync(bool withContinuation)
{
#if DEBUG
var oldAsyncCompletionCount = ConnectionMultiplexer.GetAsyncCompletionWorkerCount();
#endif
using (var muxer = Create())
{
muxer.PreserveAsyncOrder = preserveOrder;
RedisKey key = "MBOA";
var conn = muxer.GetDatabase();
await conn.PingAsync().ForAwait();
......@@ -42,36 +36,26 @@ public async Task MassiveBulkOpsAsync(bool preserveOrder, bool withContinuation)
}
Assert.Equal(AsyncOpsQty, await conn.StringGetAsync(key).ForAwait());
watch.Stop();
Output.WriteLine("{2}: Time for {0} ops: {1}ms ({3}, {4}); ops/s: {5}", AsyncOpsQty, watch.ElapsedMilliseconds, Me(),
withContinuation ? "with continuation" : "no continuation", preserveOrder ? "preserve order" : "any order",
AsyncOpsQty / watch.Elapsed.TotalSeconds);
#if DEBUG
Output.WriteLine("Async completion workers: " + (ConnectionMultiplexer.GetAsyncCompletionWorkerCount() - oldAsyncCompletionCount));
#endif
Output.WriteLine("{2}: Time for {0} ops: {1}ms ({3}, any order); ops/s: {4}", AsyncOpsQty, watch.ElapsedMilliseconds, Me(),
withContinuation ? "with continuation" : "no continuation", AsyncOpsQty / watch.Elapsed.TotalSeconds);
}
}
[Theory]
[InlineData(true, 1)]
[InlineData(false, 1)]
[InlineData(true, 5)]
[InlineData(false, 5)]
[InlineData(true, 10)]
[InlineData(false, 10)]
[InlineData(true, 50)]
[InlineData(false, 50)]
public void MassiveBulkOpsSync(bool preserveOrder, int threads)
[InlineData(1)]
[InlineData(5)]
[InlineData(10)]
[InlineData(50)]
public void MassiveBulkOpsSync(int threads)
{
int workPerThread = SyncOpsQty / threads;
using (var muxer = Create(syncTimeout: 30000))
{
muxer.PreserveAsyncOrder = preserveOrder;
RedisKey key = "MBOS";
var conn = muxer.GetDatabase();
conn.KeyDelete(key);
#if DEBUG
long oldAlloc = ConnectionMultiplexer.GetResultBoxAllocationCount();
long oldWorkerCount = ConnectionMultiplexer.GetAsyncCompletionWorkerCount();
#endif
var timeTaken = RunConcurrent(delegate
{
......@@ -83,28 +67,23 @@ public void MassiveBulkOpsSync(bool preserveOrder, int threads)
int val = (int)conn.StringGet(key);
Assert.Equal(workPerThread * threads, val);
Output.WriteLine("{2}: Time for {0} ops on {4} threads: {1}ms ({3}); ops/s: {5}",
threads * workPerThread, timeTaken.TotalMilliseconds, Me()
, preserveOrder ? "preserve order" : "any order", threads, (workPerThread * threads) / timeTaken.TotalSeconds);
Output.WriteLine("{2}: Time for {0} ops on {3} threads: {1}ms (any order); ops/s: {4}",
threads * workPerThread, timeTaken.TotalMilliseconds, Me(), threads, (workPerThread * threads) / timeTaken.TotalSeconds);
#if DEBUG
long newAlloc = ConnectionMultiplexer.GetResultBoxAllocationCount();
long newWorkerCount = ConnectionMultiplexer.GetAsyncCompletionWorkerCount();
Output.WriteLine("ResultBox allocations: {0}; workers {1}", newAlloc - oldAlloc, newWorkerCount - oldWorkerCount);
Output.WriteLine("ResultBox allocations: {0}", newAlloc - oldAlloc);
Assert.True(newAlloc - oldAlloc <= 2 * threads, "number of box allocations");
#endif
}
}
[Theory]
[InlineData(true, 1)]
[InlineData(false, 1)]
[InlineData(true, 5)]
[InlineData(false, 5)]
public void MassiveBulkOpsFireAndForget(bool preserveOrder, int threads)
[InlineData(1)]
[InlineData(5)]
public void MassiveBulkOpsFireAndForget(int threads)
{
using (var muxer = Create(syncTimeout: 30000))
{
muxer.PreserveAsyncOrder = preserveOrder;
#if DEBUG
long oldAlloc = ConnectionMultiplexer.GetResultBoxAllocationCount();
#endif
......@@ -125,9 +104,8 @@ public void MassiveBulkOpsFireAndForget(bool preserveOrder, int threads)
var val = (long)conn.StringGet(key);
Assert.Equal(perThread * threads, val);
Output.WriteLine("{2}: Time for {0} ops over {5} threads: {1:###,###}ms ({3}); ops/s: {4:###,###,##0}",
Output.WriteLine("{2}: Time for {0} ops over {4} threads: {1:###,###}ms (any order); ops/s: {3:###,###,##0}",
val, elapsed.TotalMilliseconds, Me(),
preserveOrder ? "preserve order" : "any order",
val / elapsed.TotalSeconds, threads);
#if DEBUG
long newAlloc = ConnectionMultiplexer.GetResultBoxAllocationCount();
......
......@@ -11,10 +11,8 @@ public class PreserveOrder : TestBase
{
public PreserveOrder(ITestOutputHelper output) : base (output) { }
[Theory]
[InlineData(true)]
[InlineData(false)]
public void Execute(bool preserveAsyncOrder)
[Fact]
public void Execute()
{
using (var conn = Create())
{
......@@ -33,9 +31,8 @@ public void Execute(bool preserveAsyncOrder)
Thread.Sleep(1); // you kinda need to be slow, otherwise
// the pool will end up doing everything on one thread
});
conn.PreserveAsyncOrder = preserveAsyncOrder;
Output.WriteLine("");
Output.WriteLine("Sending ({0})...", preserveAsyncOrder ? "preserved order" : "any order");
Output.WriteLine("Sending (any order)...");
lock (received)
{
received.Clear();
......@@ -65,11 +62,9 @@ public void Execute(bool preserveAsyncOrder)
if (received[i] != i) wrongOrder++;
}
Output.WriteLine("Out of order: " + wrongOrder);
if (preserveAsyncOrder) Assert.Equal(0, wrongOrder);
else Assert.NotEqual(0, wrongOrder);
}
}
}
}
}
}
\ No newline at end of file
}
......@@ -42,23 +42,16 @@ public void ExplicitPublishMode()
}
[Theory]
[InlineData(true, null, false)]
[InlineData(false, null, false)]
[InlineData(true, "", false)]
[InlineData(false, "", false)]
[InlineData(true, "Foo:", false)]
[InlineData(false, "Foo:", false)]
[InlineData(true, null, true)]
[InlineData(false, null, true)]
[InlineData(true, "", true)]
[InlineData(false, "", true)]
[InlineData(true, "Foo:", true)]
[InlineData(false, "Foo:", true)]
public void TestBasicPubSub(bool preserveOrder, string channelPrefix, bool wildCard)
[InlineData(null, false)]
[InlineData("", false)]
[InlineData("Foo:", false)]
[InlineData(null, true)]
[InlineData("", true)]
[InlineData("Foo:", true)]
public void TestBasicPubSub(string channelPrefix, bool wildCard)
{
using (var muxer = Create(channelPrefix: channelPrefix))
{
muxer.PreserveAsyncOrder = preserveOrder;
var pub = GetAnyMaster(muxer);
var sub = muxer.GetSubscriber();
Ping(muxer, pub, sub);
......@@ -123,14 +116,11 @@ public void TestBasicPubSub(bool preserveOrder, string channelPrefix, bool wildC
}
}
[Theory]
[InlineData(true)]
[InlineData(false)]
public void TestBasicPubSubFireAndForget(bool preserveOrder)
[Fact]
public void TestBasicPubSubFireAndForget()
{
using (var muxer = Create())
{
muxer.PreserveAsyncOrder = preserveOrder;
var pub = GetAnyMaster(muxer);
var sub = muxer.GetSubscriber();
......@@ -193,14 +183,11 @@ private static void Ping(ConnectionMultiplexer muxer, IServer pub, ISubscriber s
}
}
[Theory]
[InlineData(true)]
[InlineData(false)]
public void TestPatternPubSub(bool preserveOrder)
[Fact]
public void TestPatternPubSub()
{
using (var muxer = Create())
{
muxer.PreserveAsyncOrder = preserveOrder;
var pub = GetAnyMaster(muxer);
var sub = muxer.GetSubscriber();
......
......@@ -14,14 +14,11 @@ public class Secure : TestBase
public Secure(ITestOutputHelper output) : base (output) { }
[Theory]
[InlineData(true)]
[InlineData(false)]
public void MassiveBulkOpsFireAndForgetSecure(bool preserveOrder)
[Fact]
public void MassiveBulkOpsFireAndForgetSecure()
{
using (var muxer = Create())
{
muxer.PreserveAsyncOrder = preserveOrder;
#if DEBUG
long oldAlloc = ConnectionMultiplexer.GetResultBoxAllocationCount();
#endif
......@@ -38,8 +35,7 @@ public void MassiveBulkOpsFireAndForgetSecure(bool preserveOrder)
int val = (int)conn.StringGet(key);
Assert.Equal(AsyncOpsQty, val);
watch.Stop();
Output.WriteLine("{2}: Time for {0} ops: {1}ms ({3}); ops/s: {4}", AsyncOpsQty, watch.ElapsedMilliseconds, Me(),
preserveOrder ? "preserve order" : "any order",
Output.WriteLine("{2}: Time for {0} ops: {1}ms (any order); ops/s: {3}", AsyncOpsQty, watch.ElapsedMilliseconds, Me(),
AsyncOpsQty / watch.Elapsed.TotalSeconds);
#if DEBUG
long newAlloc = ConnectionMultiplexer.GetResultBoxAllocationCount();
......
using System;
using System.Threading;
using System.Threading.Tasks;
using Xunit;
namespace StackExchange.Redis.Tests
{
public class TaskTests
{
#if DEBUG
[Theory]
[InlineData(SourceOrign.NewTCS)]
[InlineData(SourceOrign.Create)]
public void VerifyIsSyncSafe(SourceOrign origin)
{
var source = Create<int>(origin);
// Yes this looks stupid, but it's the proper pattern for how we statically init now
// ...and if we're dropping NET45 support, we can just nuke it all.
#if NET462
Assert.True(TaskSource.IsSyncSafe(source.Task));
#elif NETCOREAPP2_0
Assert.True(TaskSource.IsSyncSafe(source.Task));
#endif
}
private static TaskCompletionSource<T> Create<T>(SourceOrign origin)
{
switch (origin)
{
case SourceOrign.NewTCS: return new TaskCompletionSource<T>();
case SourceOrign.Create: return TaskSource.Create<T>(null);
default: throw new ArgumentOutOfRangeException(nameof(origin));
}
}
[Theory]
// regular framework behaviour: 2 out of 3 cause hijack
[InlineData(SourceOrign.NewTCS, AttachMode.ContinueWith, true)]
[InlineData(SourceOrign.NewTCS, AttachMode.ContinueWithExecSync, false)]
[InlineData(SourceOrign.NewTCS, AttachMode.Await, true)]
// Create is just a wrapper of ^^^; expect the same
[InlineData(SourceOrign.Create, AttachMode.ContinueWith, true)]
[InlineData(SourceOrign.Create, AttachMode.ContinueWithExecSync, false)]
[InlineData(SourceOrign.Create, AttachMode.Await, true)]
public void TestContinuationHijacking(SourceOrign origin, AttachMode attachMode, bool expectHijack)
{
TaskCompletionSource<int> source = Create<int>(origin);
int settingThread = Environment.CurrentManagedThreadId;
var state = new AwaitState();
state.Attach(source.Task, attachMode);
source.TrySetResult(123);
state.Wait(); // waits for the continuation to run
int from = state.Thread;
Assert.NotEqual(-1, from); // not set
if (expectHijack)
{
Assert.True(settingThread != from, $"expected hijack; didn't happen, Origin={settingThread}, Final={from}");
}
else
{
Assert.True(settingThread == from, $"setter was hijacked, Origin={settingThread}, Final={from}");
}
}
public enum SourceOrign
{
NewTCS,
Create
}
public enum AttachMode
{
ContinueWith,
ContinueWithExecSync,
Await
}
private class AwaitState
{
public int Thread => continuationThread;
private volatile int continuationThread = -1;
private readonly ManualResetEventSlim evt = new ManualResetEventSlim();
public void Wait()
{
if (!evt.Wait(5000)) throw new TimeoutException();
}
public void Attach(Task task, AttachMode attachMode)
{
switch (attachMode)
{
case AttachMode.ContinueWith:
task.ContinueWith(Continue);
break;
case AttachMode.ContinueWithExecSync:
task.ContinueWith(Continue, TaskContinuationOptions.ExecuteSynchronously);
break;
case AttachMode.Await:
DoAwait(task);
break;
default:
throw new ArgumentOutOfRangeException(nameof(attachMode));
}
}
private void Continue(Task task)
{
continuationThread = Environment.CurrentManagedThreadId;
evt.Set();
}
private async void DoAwait(Task task)
{
await task.ConfigureAwait(false);
continuationThread = Environment.CurrentManagedThreadId;
evt.Set();
}
}
#endif
}
}
......@@ -25,5 +25,6 @@
<ItemGroup>
<PackageReference Include="Pipelines.Sockets.Unofficial" Version="0.2.1-alpha.58" />
<PackageReference Include="System.Threading.Channels" Version="4.5.0" />
</ItemGroup>
</Project>
\ No newline at end of file
using System;
using System.Threading;
using System.Threading.Channels;
using System.Threading.Tasks;
namespace StackExchange.Redis
{
/// <summary>
/// Represents a message that is broadcast via pub/sub
/// </summary>
public readonly struct ChannelMessage
{
internal ChannelMessage(RedisChannel channel, RedisValue value)
{
Channel = channel;
Value = value;
}
/// <summary>
/// The channel that the message was broadcast to
/// </summary>
public RedisChannel Channel { get; }
/// <summary>
/// The value that was broadcast
/// </summary>
public RedisValue Value { get; }
}
/// <summary>
/// Represents a message queue of ordered pub/sub notifications
/// </summary>
/// <remarks>To create a ChannelMessageQueue, use ISubscriber.Subscribe[Async](RedisKey)</remarks>
public sealed class ChannelMessageQueue
{
private readonly Channel<ChannelMessage> _channel;
private readonly RedisChannel _redisChannel;
private RedisSubscriber _parent;
/// <summary>
/// Indicates if all messages that will be received have been drained from this channel
/// </summary>
public bool IsCompleted { get; private set; }
internal ChannelMessageQueue(RedisChannel redisChannel, RedisSubscriber parent)
{
_redisChannel = redisChannel;
_parent = parent;
_channel = Channel.CreateUnbounded<ChannelMessage>(s_ChannelOptions);
_channel.Reader.Completion.ContinueWith(
(t, state) => ((ChannelMessageQueue)state).IsCompleted = true, this, TaskContinuationOptions.ExecuteSynchronously);
}
static readonly UnboundedChannelOptions s_ChannelOptions = new UnboundedChannelOptions
{
SingleWriter = true,
SingleReader = false,
AllowSynchronousContinuations = false,
};
internal void Subscribe(CommandFlags flags) => _parent.Subscribe(_redisChannel, HandleMessage, flags);
internal Task SubscribeAsync(CommandFlags flags) => _parent.SubscribeAsync(_redisChannel, HandleMessage, flags);
private void HandleMessage(RedisChannel channel, RedisValue value)
{
var writer = _channel.Writer;
if (channel.IsNull && value.IsNull) // see ForSyncShutdown
{
writer.TryComplete();
}
else
{
writer.TryWrite(new ChannelMessage(channel, value));
}
}
/// <summary>
/// Consume a message from the channel
/// </summary>
public ValueTask<ChannelMessage> ReadAsync(CancellationToken cancellationToken = default)
=> _channel.Reader.ReadAsync(cancellationToken);
/// <summary>
/// Attempt to synchronously consume a message from the channel
/// </summary>
public bool TryRead(out ChannelMessage item) => _channel.Reader.TryRead(out item);
private Delegate _onMessageHandler;
private void AssertOnMessage(Delegate handler)
{
if (handler == null) throw new ArgumentNullException(nameof(handler));
if (Interlocked.CompareExchange(ref _onMessageHandler, handler, null) != null)
throw new InvalidOperationException("Only a single " + nameof(OnMessage) + " is allowed");
}
/// <summary>
/// Create a message loop that processes messages sequentially
/// </summary>
public void OnMessage(Action<RedisChannel, RedisValue> handler)
{
AssertOnMessage(handler);
ThreadPool.QueueUserWorkItem(
state => ((ChannelMessageQueue)state).OnMessageSyncImpl(), this);
}
private async void OnMessageSyncImpl()
{
var handler = (Action<RedisChannel, RedisValue>)_onMessageHandler;
while (!IsCompleted)
{
ChannelMessage next;
try { if(!TryRead(out next)) next = await ReadAsync(); }
catch (ChannelClosedException) { break; } // expected
catch (Exception ex)
{
_parent.multiplexer?.OnInternalError(ex);
break;
}
try { handler.Invoke(next.Channel, next.Value); }
catch { } // matches MessageCompletable
}
}
/// <summary>
/// Create a message loop that processes messages sequentially
/// </summary>
public void OnMessage(Func<RedisChannel, RedisValue, Task> handler)
{
AssertOnMessage(handler);
ThreadPool.QueueUserWorkItem(
state => ((ChannelMessageQueue)state).OnMessageAsyncImpl(), this);
}
private async void OnMessageAsyncImpl()
{
var handler = (Func<RedisChannel, RedisValue, Task>)_onMessageHandler;
while (!IsCompleted)
{
ChannelMessage next;
try { if (!TryRead(out next)) next = await ReadAsync(); }
catch (ChannelClosedException) { break; } // expected
catch (Exception ex)
{
_parent.multiplexer?.OnInternalError(ex);
break;
}
try
{
var task = handler.Invoke(next.Channel, next.Value);
if (task != null) await task;
}
catch { } // matches MessageCompletable
}
}
internal void UnsubscribeImpl(Exception error = null, CommandFlags flags = CommandFlags.None)
{
var parent = _parent;
if (parent != null)
{
parent.UnsubscribeAsync(_redisChannel, HandleMessage, flags);
_parent = null;
_channel.Writer.TryComplete(error);
}
}
internal async Task UnsubscribeAsyncImpl(Exception error = null, CommandFlags flags = CommandFlags.None)
{
var parent = _parent;
if (parent != null)
{
await parent.UnsubscribeAsync(_redisChannel, HandleMessage, flags);
_parent = null;
_channel.Writer.TryComplete(error);
}
}
internal static bool IsOneOf(Action<RedisChannel, RedisValue> handler)
{
try
{
return handler != null && handler.Target is ChannelMessageQueue
&& handler.Method.Name == nameof(HandleMessage);
}
catch
{
return false;
}
}
/// <summary>
/// Stop receiving messages on this channel
/// </summary>
public void Unsubscribe(CommandFlags flags = CommandFlags.None) => UnsubscribeImpl(null, flags);
/// <summary>
/// Stop receiving messages on this channel
/// </summary>
public Task UnsubscribeAsync(CommandFlags flags = CommandFlags.None) => UnsubscribeAsyncImpl(null, flags);
}
}
......@@ -7,16 +7,10 @@ namespace StackExchange.Redis
{
internal sealed partial class CompletionManager
{
private static readonly WaitCallback processAsyncCompletionQueue = ProcessAsyncCompletionQueue,
anyOrderCompletionHandler = AnyOrderCompletionHandler;
private readonly Queue<ICompletable> asyncCompletionQueue = new Queue<ICompletable>();
private readonly ConnectionMultiplexer multiplexer;
private readonly string name;
private int activeAsyncWorkerThread = 0;
private long completedSync, completedAsync, failedAsync;
public CompletionManager(ConnectionMultiplexer multiplexer, string name)
{
......@@ -34,65 +28,20 @@ public void CompleteSyncOrAsync(ICompletable operation)
}
else
{
if (multiplexer.PreserveAsyncOrder)
{
multiplexer.Trace("Queueing for asynchronous completion", name);
bool startNewWorker;
lock (asyncCompletionQueue)
{
asyncCompletionQueue.Enqueue(operation);
startNewWorker = asyncCompletionQueue.Count == 1;
}
if (startNewWorker)
{
multiplexer.Trace("Starting new async completion worker", name);
OnCompletedAsync();
ThreadPool.QueueUserWorkItem(processAsyncCompletionQueue, this);
}
} else
{
multiplexer.Trace("Using thread-pool for asynchronous completion", name);
ThreadPool.QueueUserWorkItem(anyOrderCompletionHandler, operation);
Interlocked.Increment(ref completedAsync); // k, *technically* we haven't actually completed this yet, but: close enough
}
multiplexer.Trace("Using thread-pool for asynchronous completion", name);
multiplexer.SocketManager.ScheduleTask(s_AnyOrderCompletionHandler, operation);
Interlocked.Increment(ref completedAsync); // k, *technically* we haven't actually completed this yet, but: close enough
}
}
internal void GetCounters(ConnectionCounters counters)
{
lock (asyncCompletionQueue)
{
counters.ResponsesAwaitingAsyncCompletion = asyncCompletionQueue.Count;
}
counters.CompletedSynchronously = Interlocked.Read(ref completedSync);
counters.CompletedAsynchronously = Interlocked.Read(ref completedAsync);
counters.FailedAsynchronously = Interlocked.Read(ref failedAsync);
}
internal int GetOutstandingCount()
{
lock(asyncCompletionQueue)
{
return asyncCompletionQueue.Count;
}
}
internal void GetStormLog(StringBuilder sb)
{
lock(asyncCompletionQueue)
{
if (asyncCompletionQueue.Count == 0) return;
sb.Append("Response awaiting completion: ").Append(asyncCompletionQueue.Count).AppendLine();
int total = 0;
foreach(var item in asyncCompletionQueue)
{
if (++total >= 500) break;
item.AppendStormLog(sb);
sb.AppendLine();
}
}
}
private static readonly Action<object> s_AnyOrderCompletionHandler = AnyOrderCompletionHandler;
private static void AnyOrderCompletionHandler(object state)
{
try
......@@ -104,88 +53,6 @@ private static void AnyOrderCompletionHandler(object state)
{
ConnectionMultiplexer.TraceWithoutContext("Async completion error: " + ex.Message);
}
}
private static void ProcessAsyncCompletionQueue(object state)
{
((CompletionManager)state).ProcessAsyncCompletionQueueImpl();
}
partial void OnCompletedAsync();
private void ProcessAsyncCompletionQueueImpl()
{
int currentThread = Environment.CurrentManagedThreadId;
try
{
while (Interlocked.CompareExchange(ref activeAsyncWorkerThread, currentThread, 0) != 0)
{
// if we don't win the lock, check whether there is still work; if there is we
// need to retry to prevent a nasty race condition
lock(asyncCompletionQueue)
{
if (asyncCompletionQueue.Count == 0) return; // another thread drained it; can exit
}
Thread.Sleep(1);
}
int total = 0;
while (true)
{
ICompletable next;
lock (asyncCompletionQueue)
{
next = asyncCompletionQueue.Count == 0 ? null
: asyncCompletionQueue.Dequeue();
}
if (next == null)
{
// give it a moment and try again, noting that we might lose the battle
// when we pause
Interlocked.CompareExchange(ref activeAsyncWorkerThread, 0, currentThread);
if (SpinWait() && Interlocked.CompareExchange(ref activeAsyncWorkerThread, currentThread, 0) == 0)
{
// we paused, and we got the lock back; anything else?
lock (asyncCompletionQueue)
{
next = asyncCompletionQueue.Count == 0 ? null
: asyncCompletionQueue.Dequeue();
}
}
}
if (next == null) break; // nothing to do <===== exit point
try
{
multiplexer.Trace("Completing async (ordered): " + next, name);
next.TryComplete(true);
Interlocked.Increment(ref completedAsync);
}
catch (Exception ex)
{
multiplexer.Trace("Async completion error: " + ex.Message, name);
Interlocked.Increment(ref failedAsync);
}
total++;
}
multiplexer.Trace("Async completion worker processed " + total + " operations", name);
}
finally
{
Interlocked.CompareExchange(ref activeAsyncWorkerThread, 0, currentThread);
}
}
private bool SpinWait()
{
var sw = new SpinWait();
byte maxSpins = 128;
do
{
if (sw.NextSpinWillYield)
return true;
maxSpins--;
}
while (maxSpins > 0);
return false;
}
}
}
}
......@@ -124,7 +124,7 @@ public static string TryNormalize(string value)
}
}
private bool? allowAdmin, abortOnConnectFail, highPrioritySocketThreads, resolveDns, ssl, preserveAsyncOrder;
private bool? allowAdmin, abortOnConnectFail, highPrioritySocketThreads, resolveDns, ssl;
private string tieBreaker, sslHost, configChannel;
......@@ -258,7 +258,12 @@ public int ConnectTimeout
/// <summary>
/// Specifies whether asynchronous operations should be invoked in a way that guarantees their original delivery order
/// </summary>
public bool PreserveAsyncOrder { get { return preserveAsyncOrder.GetValueOrDefault(true); } set { preserveAsyncOrder = value; } }
[Obsolete("Not supported; if you require ordered pub/sub, please see " + nameof(ChannelMessageQueue), false)]
public bool PreserveAsyncOrder
{
get { return false; }
set { }
}
/// <summary>
/// Type of proxy to use (if any); for example Proxy.Twemproxy.
......@@ -395,7 +400,6 @@ public ConfigurationOptions Clone()
responseTimeout = responseTimeout,
DefaultDatabase = DefaultDatabase,
ReconnectRetryPolicy = reconnectRetryPolicy,
preserveAsyncOrder = preserveAsyncOrder,
SslProtocols = SslProtocols,
};
foreach (var item in EndPoints)
......@@ -453,7 +457,6 @@ public string ToString(bool includePassword)
Append(sb, OptionKeys.ConfigCheckSeconds, configCheckSeconds);
Append(sb, OptionKeys.ResponseTimeout, responseTimeout);
Append(sb, OptionKeys.DefaultDatabase, DefaultDatabase);
Append(sb, OptionKeys.PreserveAsyncOrder, preserveAsyncOrder);
commandMap?.AppendDeltas(sb);
return sb.ToString();
}
......@@ -532,7 +535,7 @@ private void Clear()
{
ClientName = ServiceName = Password = tieBreaker = sslHost = configChannel = null;
keepAlive = syncTimeout = connectTimeout = writeBuffer = connectRetry = configCheckSeconds = DefaultDatabase = null;
allowAdmin = abortOnConnectFail = highPrioritySocketThreads = resolveDns = ssl = preserveAsyncOrder = null;
allowAdmin = abortOnConnectFail = highPrioritySocketThreads = resolveDns = ssl = null;
defaultVersion = null;
EndPoints.Clear();
commandMap = null;
......@@ -646,7 +649,6 @@ private void DoParse(string configuration, bool ignoreUnknown)
DefaultDatabase = OptionKeys.ParseInt32(key, value);
break;
case OptionKeys.PreserveAsyncOrder:
PreserveAsyncOrder = OptionKeys.ParseBoolean(key, value);
break;
case OptionKeys.SslProtocols:
SslProtocols = OptionKeys.ParseSslProtocols(key, value);
......
......@@ -899,7 +899,6 @@ private ConnectionMultiplexer(ConfigurationOptions configuration)
map.AssertAvailable(RedisCommand.EXISTS);
}
PreserveAsyncOrder = configuration.PreserveAsyncOrder;
TimeoutMilliseconds = configuration.SyncTimeout;
OnCreateReaderWriter(configuration);
......@@ -1810,7 +1809,12 @@ public override string ToString()
/// <summary>
/// Gets or sets whether asynchronous operations should be invoked in a way that guarantees their original delivery order
/// </summary>
public bool PreserveAsyncOrder { get; set; }
[Obsolete("Not supported; if you require ordered pub/sub, please see " + nameof(ChannelMessageQueue), false)]
public bool PreserveAsyncOrder
{
get => false;
set { }
}
/// <summary>
/// Indicates whether any servers are connected
......@@ -2037,11 +2041,9 @@ void add(string lk, string sk, string v)
}
}
int queue = server.GetOutstandingCount(message.Command, out int inst, out int qs, out int qc, out int @in);
server.GetOutstandingCount(message.Command, out int inst, out int qs, out int @in);
add("Instantaneous", "inst", inst.ToString());
add("Queue-Length", "queue", queue.ToString());
add("Queue-Awaiting-Response", "qs", qs.ToString());
add("Queue-Completion-Outstanding", "qc", qc.ToString());
add("Inbound-Bytes", "in", @in.ToString());
add("Manager", "mgr", SocketManager?.GetState());
......@@ -2067,7 +2069,7 @@ void add(string lk, string sk, string v)
sb.Append(timeoutHelpLink);
sb.Append(")");
errMessage = sb.ToString();
if (StormLogThreshold >= 0 && queue >= StormLogThreshold && Interlocked.CompareExchange(ref haveStormLog, 1, 0) == 0)
if (StormLogThreshold >= 0 && qs >= StormLogThreshold && Interlocked.CompareExchange(ref haveStormLog, 1, 0) == 0)
{
var log = server.GetStormLog(message.Command);
if (string.IsNullOrWhiteSpace(log)) Interlocked.Exchange(ref haveStormLog, 0);
......
......@@ -129,17 +129,6 @@ void IServer.Hang(TimeSpan duration, CommandFlags flags)
}
}
internal partial class CompletionManager
{
private static long asyncCompletionWorkerCount;
#pragma warning disable RCS1047 // Non-asynchronous method name should not end with 'Async'.
partial void OnCompletedAsync() => Interlocked.Increment(ref asyncCompletionWorkerCount);
#pragma warning restore RCS1047 // Non-asynchronous method name should not end with 'Async'.
internal static long GetAsyncCompletionWorkerCount() => Interlocked.Read(ref asyncCompletionWorkerCount);
}
public partial class ConnectionMultiplexer
{
/// <summary>
......@@ -147,11 +136,6 @@ public partial class ConnectionMultiplexer
/// </summary>
public static long GetResultBoxAllocationCount() => ResultBox.GetAllocationCount();
/// <summary>
/// Gets how many async completion workers were queueud
/// </summary>
public static long GetAsyncCompletionWorkerCount() => CompletionManager.GetAsyncCompletionWorkerCount();
private volatile bool allowConnect = true,
ignoreConnect = false;
......
using System;
using System;
using System.IO;
using System.Net;
using System.Threading.Tasks;
......@@ -44,6 +44,7 @@ public interface IConnectionMultiplexer
/// <summary>
/// Gets or sets whether asynchronous operations should be invoked in a way that guarantees their original delivery order
/// </summary>
[Obsolete("Not supported; if you require ordered pub/sub, please see " + nameof(ChannelMessageQueue), false)]
bool PreserveAsyncOrder { get; set; }
/// <summary>
......@@ -282,4 +283,4 @@ public interface IConnectionMultiplexer
/// <returns>The number of instances known to have received the message (however, the actual number can be higher)</returns>
Task<long> PublishReconfigureAsync(CommandFlags flags = CommandFlags.None);
}
}
\ No newline at end of file
}
using System;
using System;
using System.Net;
using System.Threading.Channels;
using System.Threading.Tasks;
namespace StackExchange.Redis
......@@ -64,6 +65,16 @@ public interface ISubscriber : IRedis
/// <remarks>https://redis.io/commands/psubscribe</remarks>
void Subscribe(RedisChannel channel, Action<RedisChannel, RedisValue> handler, CommandFlags flags = CommandFlags.None);
/// <summary>
/// Subscribe to perform some operation when a message to the preferred/active node is broadcast, as a queue that guarantees ordered handling.
/// </summary>
/// <param name="channel">The redis channel to subscribe to.</param>
/// <param name="flags">The command flags to use.</param>
/// <returns>A channel that represents this source</returns>
/// <remarks>https://redis.io/commands/subscribe</remarks>
/// <remarks>https://redis.io/commands/psubscribe</remarks>
ChannelMessageQueue Subscribe(RedisChannel channel, CommandFlags flags = CommandFlags.None);
/// <summary>
/// Subscribe to perform some operation when a change to the preferred/active node is broadcast.
/// </summary>
......@@ -74,6 +85,16 @@ public interface ISubscriber : IRedis
/// <remarks>https://redis.io/commands/psubscribe</remarks>
Task SubscribeAsync(RedisChannel channel, Action<RedisChannel, RedisValue> handler, CommandFlags flags = CommandFlags.None);
/// <summary>
/// Subscribe to perform some operation when a change to the preferred/active node is broadcast, as a channel.
/// </summary>
/// <param name="channel">The redis channel to subscribe to.</param>
/// <param name="flags">The command flags to use.</param>
/// <returns>A channel that represents this source</returns>
/// <remarks>https://redis.io/commands/subscribe</remarks>
/// <remarks>https://redis.io/commands/psubscribe</remarks>
Task<ChannelMessageQueue> SubscribeAsync(RedisChannel channel, CommandFlags flags = CommandFlags.None);
/// <summary>
/// Inidicate to which redis server we are actively subscribed for a given channel; returns null if
/// the channel is not actively subscribed
......
......@@ -7,35 +7,50 @@ internal sealed class MessageCompletable : ICompletable
{
private readonly RedisChannel channel;
private readonly Action<RedisChannel, RedisValue> handler;
private readonly Action<RedisChannel, RedisValue> syncHandler, asyncHandler;
private readonly RedisValue message;
public MessageCompletable(RedisChannel channel, RedisValue message, Action<RedisChannel, RedisValue> handler)
public MessageCompletable(RedisChannel channel, RedisValue message, Action<RedisChannel, RedisValue> syncHandler, Action<RedisChannel, RedisValue> asyncHandler)
{
this.channel = channel;
this.message = message;
this.handler = handler;
this.syncHandler = syncHandler;
this.asyncHandler = asyncHandler;
}
public override string ToString() => (string)channel;
public bool TryComplete(bool isAsync)
{
if (handler == null) return true;
if (isAsync)
{
ConnectionMultiplexer.TraceWithoutContext("Invoking...: " + (string)channel, "Subscription");
foreach(Action<RedisChannel, RedisValue> sub in handler.GetInvocationList())
if (asyncHandler != null)
{
try { sub.Invoke(channel, message); }
catch { }
ConnectionMultiplexer.TraceWithoutContext("Invoking (async)...: " + (string)channel, "Subscription");
foreach (Action<RedisChannel, RedisValue> sub in asyncHandler.GetInvocationList())
{
try { sub.Invoke(channel, message); }
catch { }
}
ConnectionMultiplexer.TraceWithoutContext("Invoke complete (async)", "Subscription");
}
ConnectionMultiplexer.TraceWithoutContext("Invoke complete", "Subscription");
return true;
}
// needs to be called async (unless there is nothing to do!)
return false;
else
{
if (syncHandler != null)
{
ConnectionMultiplexer.TraceWithoutContext("Invoking (sync)...: " + (string)channel, "Subscription");
foreach (Action<RedisChannel, RedisValue> sub in syncHandler.GetInvocationList())
{
try { sub.Invoke(channel, message); }
catch { }
}
ConnectionMultiplexer.TraceWithoutContext("Invoke complete (sync)", "Subscription");
}
return asyncHandler == null; // anything async to do?
}
}
void ICompletable.AppendStormLog(StringBuilder sb)
......
......@@ -177,7 +177,7 @@ internal void GetCounters(ConnectionCounters counters)
physical?.GetCounters(counters);
}
internal int GetOutstandingCount(out int inst, out int qs, out int qc, out int @in)
internal void GetOutstandingCount(out int inst, out int qs, out int @in)
{// defined as: PendingUnsentItems + SentItemsAwaitingResponse + ResponsesAwaitingAsyncCompletion
inst = (int)(Interlocked.Read(ref operationCount) - Interlocked.Read(ref profileLastLog));
var tmp = physical;
......@@ -190,8 +190,6 @@ internal int GetOutstandingCount(out int inst, out int qs, out int qc, out int @
qs = tmp.GetSentAwaitingResponseCount();
@in = tmp.GetAvailableInboundBytes();
}
qc = completionManager.GetOutstandingCount();
return qs + qc;
}
internal string GetStormLog()
......@@ -200,7 +198,6 @@ internal string GetStormLog()
.Append(" at ").Append(DateTime.UtcNow)
.AppendLine().AppendLine();
physical?.GetStormLog(sb);
completionManager.GetStormLog(sb);
sb.Append("Circular op-count snapshot:");
AppendProfile(sb);
sb.AppendLine();
......
......@@ -31,15 +31,16 @@ internal Task AddSubscription(RedisChannel channel, Action<RedisChannel, RedisVa
{
if (handler != null)
{
bool asAsync = !ChannelMessageQueue.IsOneOf(handler);
lock (subscriptions)
{
if (subscriptions.TryGetValue(channel, out Subscription sub))
{
sub.Add(handler);
sub.Add(asAsync, handler);
}
else
{
sub = new Subscription(handler);
sub = new Subscription(asAsync, handler);
subscriptions.Add(channel, sub);
var task = sub.SubscribeToServer(this, channel, flags, asyncState, false);
if (task != null) return task;
......@@ -84,7 +85,11 @@ internal Task RemoveAllSubscriptions(CommandFlags flags, object asyncState)
{
foreach (var pair in subscriptions)
{
pair.Value.Remove(null); // always wipes
var msg = pair.Value.ForSyncShutdown();
if(msg != null) UnprocessableCompletionManager?.CompleteSyncOrAsync(msg);
pair.Value.Remove(true, null);
pair.Value.Remove(false, null);
var task = pair.Value.UnsubscribeFromServer(pair.Key, flags, asyncState, false);
if (task != null) last = task;
}
......@@ -97,7 +102,8 @@ internal Task RemoveSubscription(RedisChannel channel, Action<RedisChannel, Redi
{
lock (subscriptions)
{
if (subscriptions.TryGetValue(channel, out Subscription sub) && sub.Remove(handler))
bool asAsync = ChannelMessageQueue.IsOneOf(handler);
if (subscriptions.TryGetValue(channel, out Subscription sub) && sub.Remove(asAsync, handler))
{
subscriptions.Remove(channel);
var task = sub.UnsubscribeFromServer(channel, flags, asyncState, false);
......@@ -143,30 +149,46 @@ internal long ValidateSubscriptions()
private sealed class Subscription
{
private Action<RedisChannel, RedisValue> handler;
private Action<RedisChannel, RedisValue> _asyncHandler, _syncHandler;
private ServerEndPoint owner;
public Subscription(Action<RedisChannel, RedisValue> value) => handler = value;
public Subscription(bool asAsync, Action<RedisChannel, RedisValue> value)
{
if (asAsync) _asyncHandler = value;
else _syncHandler = value;
}
public void Add(Action<RedisChannel, RedisValue> value) => handler += value;
public void Add(bool asAsync, Action<RedisChannel, RedisValue> value)
{
if (asAsync) _asyncHandler += value;
else _syncHandler += value;
}
public ICompletable ForSyncShutdown()
{
var syncHandler = _syncHandler;
return syncHandler == null ? null : new MessageCompletable(default, default, syncHandler, null);
}
public ICompletable ForInvoke(RedisChannel channel, RedisValue message)
{
var tmp = handler;
return tmp == null ? null : new MessageCompletable(channel, message, tmp);
var syncHandler = _syncHandler;
var asyncHandler = _asyncHandler;
return (syncHandler == null && asyncHandler == null) ? null : new MessageCompletable(channel, message, syncHandler, asyncHandler);
}
public bool Remove(Action<RedisChannel, RedisValue> value)
public bool Remove(bool asAsync, Action<RedisChannel, RedisValue> value)
{
if (value == null)
{ // treat as blanket wipe
handler = null;
return true;
if (asAsync) _asyncHandler = null;
else _syncHandler = null;
}
else
{
return (handler -= value) == null;
if (asAsync) _asyncHandler -= value;
else _syncHandler -= value;
}
return _syncHandler == null && _asyncHandler == null;
}
public Task SubscribeToServer(ConnectionMultiplexer multiplexer, RedisChannel channel, CommandFlags flags, object asyncState, bool internalCall)
......@@ -287,12 +309,26 @@ public void Subscribe(RedisChannel channel, Action<RedisChannel, RedisValue> han
if ((flags & CommandFlags.FireAndForget) == 0) Wait(task);
}
public ChannelMessageQueue Subscribe(RedisChannel channel, CommandFlags flags = CommandFlags.None)
{
var c = new ChannelMessageQueue(channel, this);
c.Subscribe(flags);
return c;
}
public Task SubscribeAsync(RedisChannel channel, Action<RedisChannel, RedisValue> handler, CommandFlags flags = CommandFlags.None)
{
if (channel.IsNullOrEmpty) throw new ArgumentNullException(nameof(channel));
return multiplexer.AddSubscription(channel, handler, flags, asyncState);
}
public async Task<ChannelMessageQueue> SubscribeAsync(RedisChannel channel, CommandFlags flags = CommandFlags.None)
{
var c = new ChannelMessageQueue(channel, this);
await c.SubscribeAsync(flags);
return c;
}
public EndPoint SubscribedEndpoint(RedisChannel channel)
{
var server = multiplexer.GetSubscribedServer(channel);
......
......@@ -97,7 +97,7 @@ public override bool TryComplete(bool isAsync)
{
if (stateOrCompletionSource is TaskCompletionSource<T> tcs)
{
if (isAsync || TaskSource.IsSyncSafe(tcs.Task))
if (isAsync)
{
UnwrapAndRecycle(this, true, out T val, out Exception ex);
......
......@@ -367,15 +367,14 @@ internal ServerCounters GetCounters()
return counters;
}
internal int GetOutstandingCount(RedisCommand command, out int inst, out int qs, out int qc, out int @in)
internal void GetOutstandingCount(RedisCommand command, out int inst, out int qs, out int @in)
{
var bridge = GetBridge(command, false);
if (bridge == null)
{
return inst = qs = qc = @in = 0;
inst = qs = @in = 0;
}
return bridge.GetOutstandingCount(out inst, out qs, out qc, out @in);
bridge.GetOutstandingCount(out inst, out qs, out @in);
}
internal string GetProfile()
......
......@@ -158,24 +158,27 @@ private SocketManager(string name, bool useHighPrioritySocketThreads, int minThr
const long Receive_ResumeWriterThreshold = 3L * 1024 * 1024 * 1024;
var defaultPipeOptions = PipeOptions.Default;
_scheduler = new DedicatedThreadPoolPipeScheduler(name,
_schedulerPool = new DedicatedThreadPoolPipeScheduler(name + ":IO",
minWorkers: minThreads, maxWorkers: maxThreads,
priority: useHighPrioritySocketThreads ? ThreadPriority.AboveNormal : ThreadPriority.Normal);
SendPipeOptions = new PipeOptions(
defaultPipeOptions.Pool, _scheduler, _scheduler,
defaultPipeOptions.Pool, _schedulerPool, _schedulerPool,
pauseWriterThreshold: defaultPipeOptions.PauseWriterThreshold,
resumeWriterThreshold: defaultPipeOptions.ResumeWriterThreshold,
minimumSegmentSize: Math.Max(defaultPipeOptions.MinimumSegmentSize, MINIMUM_SEGMENT_SIZE),
useSynchronizationContext: false);
ReceivePipeOptions = new PipeOptions(
defaultPipeOptions.Pool, _scheduler, _scheduler,
defaultPipeOptions.Pool, _schedulerPool, _schedulerPool,
pauseWriterThreshold: Receive_PauseWriterThreshold,
resumeWriterThreshold: Receive_ResumeWriterThreshold,
minimumSegmentSize: Math.Max(defaultPipeOptions.MinimumSegmentSize, MINIMUM_SEGMENT_SIZE),
useSynchronizationContext: false);
_completionPool = new DedicatedThreadPoolPipeScheduler(name + ":Completion",
minWorkers: 1, maxWorkers: maxThreads, useThreadPoolQueueLength: 1);
}
private DedicatedThreadPoolPipeScheduler _scheduler;
private DedicatedThreadPoolPipeScheduler _schedulerPool, _completionPool;
internal readonly PipeOptions SendPipeOptions, ReceivePipeOptions;
private enum CallbackOperation
......@@ -194,8 +197,10 @@ private void Dispose(bool disposing)
// note: the scheduler *can't* be collected by itself - there will
// be threads, and those threads will be rooting the DedicatedThreadPool;
// but: we can lend a hand! We need to do this even in the finalizer
try { _scheduler?.Dispose(); } catch { }
_scheduler = null;
try { _schedulerPool?.Dispose(); } catch { }
try { _completionPool?.Dispose(); } catch { }
_schedulerPool = null;
_completionPool = null;
if (disposing)
{
GC.SuppressFinalize(this);
......@@ -354,8 +359,11 @@ private void Shutdown(Socket socket)
internal string GetState()
{
var s = _scheduler;
var s = _schedulerPool;
return s == null ? null : $"{s.BusyCount} of {s.WorkerCount} busy ({s.MaxWorkerCount} max)";
}
internal void ScheduleTask(Action<object> action, object state)
=> _completionPool.Schedule(action, state);
}
}
......@@ -5,86 +5,14 @@
namespace StackExchange.Redis
{
/// <summary>
/// We want to prevent callers hijacking the reader thread; this is a bit nasty, but works;
/// see https://stackoverflow.com/a/22588431/23354 for more information; a huge
/// thanks to Eli Arbel for spotting this (even though it is pure evil; it is *my kind of evil*)
/// </summary>
internal static class TaskSource
{
// on .NET < 4.6, it was possible to have threads hijacked; this is no longer a problem in 4.6 and core-clr 5,
// thanks to the new TaskCreationOptions.RunContinuationsAsynchronously, however we still need to be a little
// "test and react", as we could be targeting 4.5 but running on a 4.6 machine, in which case *it can still
// work the magic* (thanks to over-the-top install)
/// <summary>
/// Indicates whether the specified task will not hijack threads when results are set
/// </summary>
public static readonly Func<Task, bool> IsSyncSafe;
static TaskSource()
{
try
{
Type taskType = typeof(Task);
FieldInfo continuationField = taskType.GetField("m_continuationObject", BindingFlags.Instance | BindingFlags.NonPublic);
Type safeScenario = taskType.GetNestedType("SetOnInvokeMres", BindingFlags.NonPublic);
if (continuationField != null && continuationField.FieldType == typeof(object) && safeScenario != null)
{
var method = new DynamicMethod("IsSyncSafe", typeof(bool), new[] { typeof(Task) }, typeof(Task), true);
var il = method.GetILGenerator();
//var hasContinuation = il.DefineLabel();
il.Emit(OpCodes.Ldarg_0);
il.Emit(OpCodes.Ldfld, continuationField);
Label nonNull = il.DefineLabel(), goodReturn = il.DefineLabel();
// check if null
il.Emit(OpCodes.Brtrue_S, nonNull);
il.MarkLabel(goodReturn);
il.Emit(OpCodes.Ldc_I4_1);
il.Emit(OpCodes.Ret);
// check if is a SetOnInvokeMres - if so, we're OK
il.MarkLabel(nonNull);
il.Emit(OpCodes.Ldarg_0);
il.Emit(OpCodes.Ldfld, continuationField);
il.Emit(OpCodes.Isinst, safeScenario);
il.Emit(OpCodes.Brtrue_S, goodReturn);
il.Emit(OpCodes.Ldc_I4_0);
il.Emit(OpCodes.Ret);
IsSyncSafe = (Func<Task, bool>)method.CreateDelegate(typeof(Func<Task, bool>));
// and test them (check for an exception etc)
var tcs = new TaskCompletionSource<int>();
bool expectTrue = IsSyncSafe(tcs.Task);
tcs.Task.ContinueWith(delegate { });
bool expectFalse = IsSyncSafe(tcs.Task);
tcs.SetResult(0);
if (!expectTrue || expectFalse)
{
// revert to not trusting /them
IsSyncSafe = null;
}
}
}
catch (Exception)
{
IsSyncSafe = null;
}
if (IsSyncSafe == null)
{
IsSyncSafe = _ => false; // assume: not
}
}
/// <summary>
/// Create a new TaskCompletion source
/// </summary>
/// <typeparam name="T">The type for the created <see cref="TaskCompletionSource{TResult}"/>.</typeparam>
/// <param name="asyncState">The state for the created <see cref="TaskCompletionSource{TResult}"/>.</param>
public static TaskCompletionSource<T> Create<T>(object asyncState)
{
return new TaskCompletionSource<T>(asyncState, TaskCreationOptions.None);
}
=> new TaskCompletionSource<T>(asyncState, TaskCreationOptions.None);
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment