Unverified Commit 9a213f91 authored by Marc Gravell's avatar Marc Gravell Committed by GitHub

Result box simplify (#1064)

* simplify the whole result-box/TaskCompletionSource mess with the realization that the TCS *can be* the result-box, and simple (non-TCS) boxes can be [ThreadStatic]

* rev Pipelines.Sockets.Unofficial (removed AwaitableLockToken)
parent 68739410
 
Microsoft Visual Studio Solution File, Format Version 12.00 Microsoft Visual Studio Solution File, Format Version 12.00
# Visual Studio 15 # Visual Studio Version 16
VisualStudioVersion = 15.0.26823.1 VisualStudioVersion = 16.0.28531.58
MinimumVisualStudioVersion = 10.0.40219.1 MinimumVisualStudioVersion = 10.0.40219.1
Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Solution Items", "Solution Items", "{3AD17044-6BFF-4750-9AC2-2CA466375F2A}" Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Solution Items", "Solution Items", "{3AD17044-6BFF-4750-9AC2-2CA466375F2A}"
ProjectSection(SolutionItems) = preProject ProjectSection(SolutionItems) = preProject
......
...@@ -309,7 +309,7 @@ public static Condition StringNotEqual(RedisKey key, RedisValue value) ...@@ -309,7 +309,7 @@ public static Condition StringNotEqual(RedisKey key, RedisValue value)
internal abstract void CheckCommands(CommandMap commandMap); internal abstract void CheckCommands(CommandMap commandMap);
internal abstract IEnumerable<Message> CreateMessages(int db, ResultBox resultBox); internal abstract IEnumerable<Message> CreateMessages(int db, IResultBox resultBox);
internal abstract int GetHashSlot(ServerSelectionStrategy serverSelectionStrategy); internal abstract int GetHashSlot(ServerSelectionStrategy serverSelectionStrategy);
internal abstract bool TryValidate(in RawResult result, out bool value); internal abstract bool TryValidate(in RawResult result, out bool value);
...@@ -439,7 +439,7 @@ public override string ToString() ...@@ -439,7 +439,7 @@ public override string ToString()
internal override void CheckCommands(CommandMap commandMap) => commandMap.AssertAvailable(cmd); internal override void CheckCommands(CommandMap commandMap) => commandMap.AssertAvailable(cmd);
internal override IEnumerable<Message> CreateMessages(int db, ResultBox resultBox) internal override IEnumerable<Message> CreateMessages(int db, IResultBox resultBox)
{ {
yield return Message.Create(db, CommandFlags.None, RedisCommand.WATCH, key); yield return Message.Create(db, CommandFlags.None, RedisCommand.WATCH, key);
...@@ -519,7 +519,7 @@ public override string ToString() ...@@ -519,7 +519,7 @@ public override string ToString()
internal override void CheckCommands(CommandMap commandMap) => commandMap.AssertAvailable(cmd); internal override void CheckCommands(CommandMap commandMap) => commandMap.AssertAvailable(cmd);
internal sealed override IEnumerable<Message> CreateMessages(int db, ResultBox resultBox) internal sealed override IEnumerable<Message> CreateMessages(int db, IResultBox resultBox)
{ {
yield return Message.Create(db, CommandFlags.None, RedisCommand.WATCH, key); yield return Message.Create(db, CommandFlags.None, RedisCommand.WATCH, key);
...@@ -601,7 +601,7 @@ internal override void CheckCommands(CommandMap commandMap) ...@@ -601,7 +601,7 @@ internal override void CheckCommands(CommandMap commandMap)
commandMap.AssertAvailable(RedisCommand.LINDEX); commandMap.AssertAvailable(RedisCommand.LINDEX);
} }
internal sealed override IEnumerable<Message> CreateMessages(int db, ResultBox resultBox) internal sealed override IEnumerable<Message> CreateMessages(int db, IResultBox resultBox)
{ {
yield return Message.Create(db, CommandFlags.None, RedisCommand.WATCH, key); yield return Message.Create(db, CommandFlags.None, RedisCommand.WATCH, key);
...@@ -700,7 +700,7 @@ internal override void CheckCommands(CommandMap commandMap) ...@@ -700,7 +700,7 @@ internal override void CheckCommands(CommandMap commandMap)
commandMap.AssertAvailable(cmd); commandMap.AssertAvailable(cmd);
} }
internal sealed override IEnumerable<Message> CreateMessages(int db, ResultBox resultBox) internal sealed override IEnumerable<Message> CreateMessages(int db, IResultBox resultBox)
{ {
yield return Message.Create(db, CommandFlags.None, RedisCommand.WATCH, key); yield return Message.Create(db, CommandFlags.None, RedisCommand.WATCH, key);
...@@ -763,7 +763,7 @@ public override string ToString() ...@@ -763,7 +763,7 @@ public override string ToString()
internal override void CheckCommands(CommandMap commandMap) => commandMap.AssertAvailable(RedisCommand.ZCOUNT); internal override void CheckCommands(CommandMap commandMap) => commandMap.AssertAvailable(RedisCommand.ZCOUNT);
internal sealed override IEnumerable<Message> CreateMessages(int db, ResultBox resultBox) internal sealed override IEnumerable<Message> CreateMessages(int db, IResultBox resultBox)
{ {
yield return Message.Create(db, CommandFlags.None, RedisCommand.WATCH, key); yield return Message.Create(db, CommandFlags.None, RedisCommand.WATCH, key);
...@@ -799,14 +799,14 @@ public sealed class ConditionResult ...@@ -799,14 +799,14 @@ public sealed class ConditionResult
{ {
internal readonly Condition Condition; internal readonly Condition Condition;
private ResultBox<bool> resultBox; private IResultBox<bool> resultBox;
private volatile bool wasSatisfied; private volatile bool wasSatisfied;
internal ConditionResult(Condition condition) internal ConditionResult(Condition condition)
{ {
Condition = condition; Condition = condition;
resultBox = ResultBox<bool>.Get(condition); resultBox = SimpleResultBox<bool>.Create();
} }
/// <summary> /// <summary>
...@@ -816,12 +816,12 @@ internal ConditionResult(Condition condition) ...@@ -816,12 +816,12 @@ internal ConditionResult(Condition condition)
internal IEnumerable<Message> CreateMessages(int db) => Condition.CreateMessages(db, resultBox); internal IEnumerable<Message> CreateMessages(int db) => Condition.CreateMessages(db, resultBox);
internal ResultBox<bool> GetBox() { return resultBox; } internal IResultBox<bool> GetBox() { return resultBox; }
internal bool UnwrapBox() internal bool UnwrapBox()
{ {
if (resultBox != null) if (resultBox != null)
{ {
ResultBox<bool>.UnwrapAndRecycle(resultBox, false, out bool val, out Exception ex); bool val = resultBox.GetResult(out var ex);
resultBox = null; resultBox = null;
wasSatisfied = ex == null && val; wasSatisfied = ex == null && val;
} }
......
...@@ -1910,7 +1910,7 @@ internal ServerEndPoint SelectServer(RedisCommand command, CommandFlags flags, i ...@@ -1910,7 +1910,7 @@ internal ServerEndPoint SelectServer(RedisCommand command, CommandFlags flags, i
return ServerSelectionStrategy.Select(command, key, flags); return ServerSelectionStrategy.Select(command, key, flags);
} }
private bool PrepareToPushMessageToBridge<T>(Message message, ResultProcessor<T> processor, ResultBox<T> resultBox, ref ServerEndPoint server) private bool PrepareToPushMessageToBridge<T>(Message message, ResultProcessor<T> processor, IResultBox<T> resultBox, ref ServerEndPoint server)
{ {
message.SetSource(processor, resultBox); message.SetSource(processor, resultBox);
...@@ -1964,12 +1964,12 @@ private bool PrepareToPushMessageToBridge<T>(Message message, ResultProcessor<T> ...@@ -1964,12 +1964,12 @@ private bool PrepareToPushMessageToBridge<T>(Message message, ResultProcessor<T>
Trace("No server or server unavailable - aborting: " + message); Trace("No server or server unavailable - aborting: " + message);
return false; return false;
} }
private ValueTask<WriteResult> TryPushMessageToBridgeAsync<T>(Message message, ResultProcessor<T> processor, ResultBox<T> resultBox, ref ServerEndPoint server) private ValueTask<WriteResult> TryPushMessageToBridgeAsync<T>(Message message, ResultProcessor<T> processor, IResultBox<T> resultBox, ref ServerEndPoint server)
=> PrepareToPushMessageToBridge(message, processor, resultBox, ref server) ? server.TryWriteAsync(message) : new ValueTask<WriteResult>(WriteResult.NoConnectionAvailable); => PrepareToPushMessageToBridge(message, processor, resultBox, ref server) ? server.TryWriteAsync(message) : new ValueTask<WriteResult>(WriteResult.NoConnectionAvailable);
[Obsolete("prefer async")] [Obsolete("prefer async")]
#pragma warning disable CS0618 #pragma warning disable CS0618
private WriteResult TryPushMessageToBridgeSync<T>(Message message, ResultProcessor<T> processor, ResultBox<T> resultBox, ref ServerEndPoint server) private WriteResult TryPushMessageToBridgeSync<T>(Message message, ResultProcessor<T> processor, IResultBox<T> resultBox, ref ServerEndPoint server)
=> PrepareToPushMessageToBridge(message, processor, resultBox, ref server) ? server.TryWriteSync(message) : WriteResult.NoConnectionAvailable; => PrepareToPushMessageToBridge(message, processor, resultBox, ref server) ? server.TryWriteSync(message) : WriteResult.NoConnectionAvailable;
#pragma warning restore CS0618 #pragma warning restore CS0618
...@@ -2127,11 +2127,10 @@ internal Task<T> ExecuteAsyncImpl<T>(Message message, ResultProcessor<T> process ...@@ -2127,11 +2127,10 @@ internal Task<T> ExecuteAsyncImpl<T>(Message message, ResultProcessor<T> process
} }
TaskCompletionSource<T> tcs = null; TaskCompletionSource<T> tcs = null;
ResultBox<T> source = null; IResultBox<T> source = null;
if (!message.IsFireAndForget) if (!message.IsFireAndForget)
{ {
tcs = TaskSource.Create<T>(state); source = TaskResultBox<T>.Create(out tcs, state);
source = ResultBox<T>.Get(tcs);
} }
var write = TryPushMessageToBridgeAsync(message, processor, source, ref server); var write = TryPushMessageToBridgeAsync(message, processor, source, ref server);
if (!write.IsCompletedSuccessfully) return ExecuteAsyncImpl_Awaited<T>(this, write, tcs, message, server); if (!write.IsCompletedSuccessfully) return ExecuteAsyncImpl_Awaited<T>(this, write, tcs, message, server);
...@@ -2231,7 +2230,7 @@ internal T ExecuteSyncImpl<T>(Message message, ResultProcessor<T> processor, Ser ...@@ -2231,7 +2230,7 @@ internal T ExecuteSyncImpl<T>(Message message, ResultProcessor<T> processor, Ser
} }
else else
{ {
var source = ResultBox<T>.Get(null); var source = SimpleResultBox<T>.Get();
lock (source) lock (source)
{ {
...@@ -2256,7 +2255,7 @@ internal T ExecuteSyncImpl<T>(Message message, ResultProcessor<T> processor, Ser ...@@ -2256,7 +2255,7 @@ internal T ExecuteSyncImpl<T>(Message message, ResultProcessor<T> processor, Ser
} }
} }
// snapshot these so that we can recycle the box // snapshot these so that we can recycle the box
ResultBox<T>.UnwrapAndRecycle(source, true, out T val, out Exception ex); // now that we aren't locking it... var val = source.GetResult(out var ex, canRecycle: true); // now that we aren't locking it...
if (ex != null) throw ex; if (ex != null) throw ex;
Trace(message + " received " + val); Trace(message + " received " + val);
return val; return val;
......
...@@ -75,7 +75,7 @@ internal abstract class Message : ICompletable ...@@ -75,7 +75,7 @@ internal abstract class Message : ICompletable
| CommandFlags.FireAndForget | CommandFlags.FireAndForget
| CommandFlags.NoRedirect | CommandFlags.NoRedirect
| CommandFlags.NoScriptCache; | CommandFlags.NoScriptCache;
private ResultBox resultBox; private IResultBox resultBox;
private ResultProcessor resultProcessor; private ResultProcessor resultProcessor;
...@@ -206,7 +206,7 @@ internal void SetScriptUnavailable() ...@@ -206,7 +206,7 @@ internal void SetScriptUnavailable()
public bool IsFireAndForget => (Flags & CommandFlags.FireAndForget) != 0; public bool IsFireAndForget => (Flags & CommandFlags.FireAndForget) != 0;
public bool IsInternalCall => (Flags & InternalCallFlag) != 0; public bool IsInternalCall => (Flags & InternalCallFlag) != 0;
public ResultBox ResultBox => resultBox; public IResultBox ResultBox => resultBox;
public abstract int ArgCount { get; } // note: over-estimate if necessary public abstract int ArgCount { get; } // note: over-estimate if necessary
...@@ -617,7 +617,15 @@ internal virtual void SetExceptionAndComplete(Exception exception, PhysicalBridg ...@@ -617,7 +617,15 @@ internal virtual void SetExceptionAndComplete(Exception exception, PhysicalBridg
bridge.CompleteSyncOrAsync(this); bridge.CompleteSyncOrAsync(this);
} }
internal bool TrySetResult<T>(T value) => resultBox is ResultBox<T> typed && typed.TrySetResult(value); internal bool TrySetResult<T>(T value)
{
if (resultBox is IResultBox<T> typed && !typed.IsFaulted)
{
typed.SetResult(value);
return true;
}
return false;
}
internal void SetEnqueued(PhysicalConnection connection) internal void SetEnqueued(PhysicalConnection connection)
{ {
...@@ -711,14 +719,14 @@ internal void SetPreferSlave() ...@@ -711,14 +719,14 @@ internal void SetPreferSlave()
Flags = (Flags & ~MaskMasterServerPreference) | CommandFlags.PreferSlave; Flags = (Flags & ~MaskMasterServerPreference) | CommandFlags.PreferSlave;
} }
internal void SetSource(ResultProcessor resultProcessor, ResultBox resultBox) internal void SetSource(ResultProcessor resultProcessor, IResultBox resultBox)
{ // note order here reversed to prevent overload resolution errors { // note order here reversed to prevent overload resolution errors
if (resultBox != null && resultBox.IsAsync) SetNeedsTimeoutCheck(); if (resultBox != null && resultBox.IsAsync) SetNeedsTimeoutCheck();
this.resultBox = resultBox; this.resultBox = resultBox;
this.resultProcessor = resultProcessor; this.resultProcessor = resultProcessor;
} }
internal void SetSource<T>(ResultBox<T> resultBox, ResultProcessor<T> resultProcessor) internal void SetSource<T>(IResultBox<T> resultBox, ResultProcessor<T> resultProcessor)
{ {
if (resultBox != null && resultBox.IsAsync) SetNeedsTimeoutCheck(); if (resultBox != null && resultBox.IsAsync) SetNeedsTimeoutCheck();
this.resultBox = resultBox; this.resultBox = resultBox;
......
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
using System.Threading.Channels; using System.Threading.Channels;
using System.Threading.Tasks; using System.Threading.Tasks;
using Pipelines.Sockets.Unofficial.Threading; using Pipelines.Sockets.Unofficial.Threading;
using static Pipelines.Sockets.Unofficial.Threading.MutexSlim;
using PendingSubscriptionState = global::StackExchange.Redis.ConnectionMultiplexer.Subscription.PendingSubscriptionState; using PendingSubscriptionState = global::StackExchange.Redis.ConnectionMultiplexer.Subscription.PendingSubscriptionState;
namespace StackExchange.Redis namespace StackExchange.Redis
...@@ -684,7 +685,7 @@ private WriteResult WriteMessageInsideLock(PhysicalConnection physical, Message ...@@ -684,7 +685,7 @@ private WriteResult WriteMessageInsideLock(PhysicalConnection physical, Message
} }
} }
private async ValueTask<WriteResult> WriteMessageTakingDelayedWriteLockAsync(MutexSlim.AwaitableLockToken pendingLock, PhysicalConnection physical, Message message) private async ValueTask<WriteResult> WriteMessageTakingDelayedWriteLockAsync(ValueTask<LockToken> pendingLock, PhysicalConnection physical, Message message)
{ {
try try
{ {
...@@ -763,7 +764,7 @@ internal ValueTask<WriteResult> WriteMessageTakingWriteLockAsync(PhysicalConnect ...@@ -763,7 +764,7 @@ internal ValueTask<WriteResult> WriteMessageTakingWriteLockAsync(PhysicalConnect
message.SetEnqueued(physical); // this also records the read/write stats at this point message.SetEnqueued(physical); // this also records the read/write stats at this point
bool releaseLock = false; bool releaseLock = false;
MutexSlim.LockToken token = default; LockToken token = default;
try try
{ {
// try to acquire it synchronously // try to acquire it synchronously
...@@ -772,7 +773,7 @@ internal ValueTask<WriteResult> WriteMessageTakingWriteLockAsync(PhysicalConnect ...@@ -772,7 +773,7 @@ internal ValueTask<WriteResult> WriteMessageTakingWriteLockAsync(PhysicalConnect
if (!pending.IsCompletedSuccessfully) return WriteMessageTakingDelayedWriteLockAsync(pending, physical, message); if (!pending.IsCompletedSuccessfully) return WriteMessageTakingDelayedWriteLockAsync(pending, physical, message);
releaseLock = true; releaseLock = true;
token = pending.GetResult(); // we can't use "using" for this, because we might not want to kill it yet token = pending.Result; // we can't use "using" for this, because we might not want to kill it yet
if (!token.Success) // (in particular, me might hand the lifetime to CompleteWriteAndReleaseLockAsync) if (!token.Success) // (in particular, me might hand the lifetime to CompleteWriteAndReleaseLockAsync)
{ {
message.Cancel(); message.Cancel();
...@@ -807,7 +808,7 @@ internal ValueTask<WriteResult> WriteMessageTakingWriteLockAsync(PhysicalConnect ...@@ -807,7 +808,7 @@ internal ValueTask<WriteResult> WriteMessageTakingWriteLockAsync(PhysicalConnect
} }
} }
private async ValueTask<WriteResult> CompleteWriteAndReleaseLockAsync(MutexSlim.LockToken lockToken, ValueTask<WriteResult> flush, Message message) private async ValueTask<WriteResult> CompleteWriteAndReleaseLockAsync(LockToken lockToken, ValueTask<WriteResult> flush, Message message)
{ {
using (lockToken) using (lockToken)
{ {
......
...@@ -76,10 +76,9 @@ internal override Task<T> ExecuteAsync<T>(Message message, ResultProcessor<T> pr ...@@ -76,10 +76,9 @@ internal override Task<T> ExecuteAsync<T>(Message message, ResultProcessor<T> pr
} }
else else
{ {
var tcs = TaskSource.Create<T>(asyncState); var source = TaskResultBox<T>.Create(out var tcs, asyncState);
var source = ResultBox<T>.Get(tcs);
message.SetSource(source, processor);
task = tcs.Task; task = tcs.Task;
message.SetSource(source, processor);
} }
// store it // store it
......
...@@ -3714,7 +3714,7 @@ protected override SortedSetEntry[] Parse(in RawResult result) ...@@ -3714,7 +3714,7 @@ protected override SortedSetEntry[] Parse(in RawResult result)
private class StringGetWithExpiryMessage : Message.CommandKeyBase, IMultiMessage private class StringGetWithExpiryMessage : Message.CommandKeyBase, IMultiMessage
{ {
private readonly RedisCommand ttlCommand; private readonly RedisCommand ttlCommand;
private ResultBox<TimeSpan?> box; private IResultBox<TimeSpan?> box;
public StringGetWithExpiryMessage(int db, CommandFlags flags, RedisCommand ttlCommand, in RedisKey key) public StringGetWithExpiryMessage(int db, CommandFlags flags, RedisCommand ttlCommand, in RedisKey key)
: base(db, flags, RedisCommand.GET, key) : base(db, flags, RedisCommand.GET, key)
...@@ -3726,7 +3726,7 @@ public StringGetWithExpiryMessage(int db, CommandFlags flags, RedisCommand ttlCo ...@@ -3726,7 +3726,7 @@ public StringGetWithExpiryMessage(int db, CommandFlags flags, RedisCommand ttlCo
public IEnumerable<Message> GetMessages(PhysicalConnection connection) public IEnumerable<Message> GetMessages(PhysicalConnection connection)
{ {
box = ResultBox<TimeSpan?>.Get(null); box = SimpleResultBox<TimeSpan?>.Create();
var ttl = Message.Create(Db, Flags, ttlCommand, Key); var ttl = Message.Create(Db, Flags, ttlCommand, Key);
var proc = ttlCommand == RedisCommand.PTTL ? ResultProcessor.TimeSpanFromMilliseconds : ResultProcessor.TimeSpanFromSeconds; var proc = ttlCommand == RedisCommand.PTTL ? ResultProcessor.TimeSpanFromMilliseconds : ResultProcessor.TimeSpanFromSeconds;
ttl.SetSource(proc, box); ttl.SetSource(proc, box);
...@@ -3738,7 +3738,7 @@ public bool UnwrapValue(out TimeSpan? value, out Exception ex) ...@@ -3738,7 +3738,7 @@ public bool UnwrapValue(out TimeSpan? value, out Exception ex)
{ {
if (box != null) if (box != null)
{ {
ResultBox<TimeSpan?>.UnwrapAndRecycle(box, false, out value, out ex); value = box.GetResult(out ex);
box = null; box = null;
return ex == null; return ex == null;
} }
......
...@@ -259,12 +259,11 @@ private PendingSubscriptionState(object asyncState, RedisChannel channel, Subscr ...@@ -259,12 +259,11 @@ private PendingSubscriptionState(object asyncState, RedisChannel channel, Subscr
: (channel.IsPatternBased ? RedisCommand.PUNSUBSCRIBE : RedisCommand.UNSUBSCRIBE); : (channel.IsPatternBased ? RedisCommand.PUNSUBSCRIBE : RedisCommand.UNSUBSCRIBE);
var msg = Message.Create(-1, flags, cmd, channel); var msg = Message.Create(-1, flags, cmd, channel);
if (internalCall) msg.SetInternalCall(); if (internalCall) msg.SetInternalCall();
var taskSource = TaskSource.Create<bool>(asyncState);
var source = ResultBox<bool>.Get(taskSource); var source = TaskResultBox<bool>.Create(out _taskSource, asyncState);
msg.SetSource(ResultProcessor.TrackSubscriptions, source); msg.SetSource(ResultProcessor.TrackSubscriptions, source);
Subscription = subscription; Subscription = subscription;
_taskSource = taskSource;
Message = msg; Message = msg;
IsSlave = isSlave; IsSlave = isSlave;
} }
......
...@@ -73,15 +73,14 @@ internal override Task<T> ExecuteAsync<T>(Message message, ResultProcessor<T> pr ...@@ -73,15 +73,14 @@ internal override Task<T> ExecuteAsync<T>(Message message, ResultProcessor<T> pr
} }
else else
{ {
var tcs = TaskSource.Create<T>(asyncState, TaskCreationOptions.RunContinuationsAsynchronously); var source = TaskResultBox<T>.Create(out var tcs, asyncState, TaskCreationOptions.RunContinuationsAsynchronously);
var source = ResultBox<T>.Get(tcs);
message.SetSource(source, processor); message.SetSource(source, processor);
task = tcs.Task; task = tcs.Task;
} }
// prepare an outer-command that decorates that, but expects QUEUED // prepare an outer-command that decorates that, but expects QUEUED
var queued = new QueuedMessage(message); var queued = new QueuedMessage(message);
var wasQueued = ResultBox<bool>.Get(null); var wasQueued = SimpleResultBox<bool>.Create();
queued.SetSource(wasQueued, QueuedProcessor.Default); queued.SetSource(wasQueued, QueuedProcessor.Default);
// store it, and return the task of the *outer* command // store it, and return the task of the *outer* command
...@@ -100,7 +99,7 @@ internal override Task<T> ExecuteAsync<T>(Message message, ResultProcessor<T> pr ...@@ -100,7 +99,7 @@ internal override Task<T> ExecuteAsync<T>(Message message, ResultProcessor<T> pr
// think it should be! // think it should be!
var sel = PhysicalConnection.GetSelectDatabaseCommand(message.Db); var sel = PhysicalConnection.GetSelectDatabaseCommand(message.Db);
queued = new QueuedMessage(sel); queued = new QueuedMessage(sel);
wasQueued = ResultBox<bool>.Get(null); wasQueued = SimpleResultBox<bool>.Create();
queued.SetSource(wasQueued, QueuedProcessor.Default); queued.SetSource(wasQueued, QueuedProcessor.Default);
_pending.Add(queued); _pending.Add(queued);
break; break;
...@@ -240,7 +239,7 @@ public override int GetHashSlot(ServerSelectionStrategy serverSelectionStrategy) ...@@ -240,7 +239,7 @@ public override int GetHashSlot(ServerSelectionStrategy serverSelectionStrategy)
public IEnumerable<Message> GetMessages(PhysicalConnection connection) public IEnumerable<Message> GetMessages(PhysicalConnection connection)
{ {
ResultBox lastBox = null; IResultBox lastBox = null;
var bridge = connection.BridgeCouldBeNull; var bridge = connection.BridgeCouldBeNull;
if (bridge == null) throw new ObjectDisposedException(connection.ToString()); if (bridge == null) throw new ObjectDisposedException(connection.ToString());
...@@ -268,7 +267,7 @@ public IEnumerable<Message> GetMessages(PhysicalConnection connection) ...@@ -268,7 +267,7 @@ public IEnumerable<Message> GetMessages(PhysicalConnection connection)
{ {
// need to have locked them before sending them // need to have locked them before sending them
// to guarantee that we see the pulse // to guarantee that we see the pulse
ResultBox latestBox = conditions[i].GetBox(); IResultBox latestBox = conditions[i].GetBox();
Monitor.Enter(latestBox); Monitor.Enter(latestBox);
if (lastBox != null) Monitor.Exit(lastBox); if (lastBox != null) Monitor.Exit(lastBox);
lastBox = latestBox; lastBox = latestBox;
...@@ -328,7 +327,7 @@ public IEnumerable<Message> GetMessages(PhysicalConnection connection) ...@@ -328,7 +327,7 @@ public IEnumerable<Message> GetMessages(PhysicalConnection connection)
if (explicitCheckForQueued) if (explicitCheckForQueued)
{ // need to have locked them before sending them { // need to have locked them before sending them
// to guarantee that we see the pulse // to guarantee that we see the pulse
ResultBox thisBox = op.ResultBox; IResultBox thisBox = op.ResultBox;
if (thisBox != null) if (thisBox != null)
{ {
Monitor.Enter(thisBox); Monitor.Enter(thisBox);
......
...@@ -4,146 +4,153 @@ ...@@ -4,146 +4,153 @@
namespace StackExchange.Redis namespace StackExchange.Redis
{ {
internal abstract class ResultBox internal interface IResultBox
{ {
protected Exception _exception; bool IsAsync { get; }
public abstract bool IsAsync { get; } bool IsFaulted { get; }
public bool IsFaulted => _exception != null; void SetException(Exception ex);
bool TryComplete(bool isAsync);
void Cancel();
}
internal interface IResultBox<T> : IResultBox
{
T GetResult(out Exception ex, bool canRecycle = false);
void SetResult(T value);
}
public void SetException(Exception exception) => _exception = exception ?? s_cancelled; internal abstract class SimpleResultBox : IResultBox
{
private volatile Exception _exception;
public abstract bool TryComplete(bool isAsync); bool IResultBox.IsAsync => false;
bool IResultBox.IsFaulted => _exception != null;
void IResultBox.SetException(Exception exception) => _exception = exception ?? CancelledException;
void IResultBox.Cancel() => _exception = CancelledException;
public void Cancel() => _exception = s_cancelled; bool IResultBox.TryComplete(bool isAsync)
{
lock (this)
{ // tell the waiting thread that we're done
Monitor.PulseAll(this);
}
ConnectionMultiplexer.TraceWithoutContext("Pulsed", "Result");
return true;
}
// in theory nobody should directly observe this; the only things // in theory nobody should directly observe this; the only things
// that call Cancel are transactions etc - TCS-based, and we detect // that call Cancel are transactions etc - TCS-based, and we detect
// that and use TrySetCanceled instead // that and use TrySetCanceled instead
// about any confusion in stack-trace // about any confusion in stack-trace
private static readonly Exception s_cancelled = new TaskCanceledException(); internal static readonly Exception CancelledException = new TaskCanceledException();
protected Exception Exception
{
get => _exception;
set => _exception = value;
}
} }
internal sealed class ResultBox<T> : ResultBox internal sealed class SimpleResultBox<T> : SimpleResultBox, IResultBox<T>
{ {
private static readonly ResultBox<T>[] store = new ResultBox<T>[64]; private SimpleResultBox() { }
private object stateOrCompletionSource; private T _value;
private int _usageCount;
private T value;
public ResultBox(object stateOrCompletionSource) [ThreadStatic]
{ private static SimpleResultBox<T> _perThreadInstance;
this.stateOrCompletionSource = stateOrCompletionSource;
_usageCount = 1;
}
public static ResultBox<T> Get(object stateOrCompletionSource) public static IResultBox<T> Create() => new SimpleResultBox<T>();
public static IResultBox<T> Get() // includes recycled boxes; used from sync, so makes re-use easy
{ {
ResultBox<T> found; var obj = _perThreadInstance ?? new SimpleResultBox<T>();
for (int i = 0; i < store.Length; i++) _perThreadInstance = null; // in case of oddness; only set back when recycled
{ return obj;
if ((found = Interlocked.Exchange(ref store[i], null)) != null)
{
found.Reset(stateOrCompletionSource);
return found;
}
}
return new ResultBox<T>(stateOrCompletionSource);
} }
void IResultBox<T>.SetResult(T value) => _value = value;
public static void UnwrapAndRecycle(ResultBox<T> box, bool recycle, out T value, out Exception exception) T IResultBox<T>.GetResult(out Exception ex, bool canRecycle)
{ {
if (box == null) var value = _value;
{ ex = Exception;
value = default(T); if (canRecycle)
exception = null;
}
else
{ {
value = box.value; Exception = null;
exception = box._exception; _value = default;
box.value = default(T); _perThreadInstance = this;
box._exception = null;
if (recycle)
{
var newCount = Interlocked.Decrement(ref box._usageCount);
if (newCount != 0)
throw new InvalidOperationException($"Result box count error: is {newCount} in UnwrapAndRecycle (should be 0)");
// Clear state prior to recycling, so as not to root it
box.stateOrCompletionSource = null;
for (int i = 0; i < store.Length; i++)
{
if (Interlocked.CompareExchange(ref store[i], box, null) == null) return;
}
}
} }
return value;
} }
}
public void SetResult(T value) internal sealed class TaskResultBox<T> : TaskCompletionSource<T>, IResultBox<T>
{ {
this.value = value; // you might be asking "wait, doesn't the Task own these?", to which
} // I say: no; we can't set *immediately* due to thread-theft etc, hence
// the fun TryComplete indirection - so we need somewhere to buffer them
private volatile Exception _exception;
private T _value;
private TaskResultBox(object asyncState, TaskCreationOptions creationOptions) : base(asyncState, creationOptions)
{ }
bool IResultBox.IsAsync => true;
bool IResultBox.IsFaulted => _exception != null;
internal bool TrySetResult(T value) void IResultBox.Cancel() => _exception = SimpleResultBox.CancelledException;
void IResultBox.SetException(Exception ex) => _exception = ex ?? SimpleResultBox.CancelledException;
void IResultBox<T>.SetResult(T value) => _value = value;
T IResultBox<T>.GetResult(out Exception ex, bool _)
{ {
if (_exception != null) return false; ex = _exception;
this.value = value; return _value;
return true; // nothing to do re recycle: TaskCompletionSource<T> cannot be recycled
} }
public override bool IsAsync => stateOrCompletionSource is TaskCompletionSource<T>; bool IResultBox.TryComplete(bool isAsync)
public override bool TryComplete(bool isAsync)
{ {
if (stateOrCompletionSource is TaskCompletionSource<T> tcs) if (isAsync || (Task.CreationOptions & TaskCreationOptions.RunContinuationsAsynchronously) != 0)
{ {
if (isAsync || (tcs.Task.CreationOptions & TaskCreationOptions.RunContinuationsAsynchronously) != 0) // either on the async completion step, or the task is guarded
// againsts thread-stealing; complete it directly
// (note: RunContinuationsAsynchronously is only usable from NET46)
var val = _value;
var ex = _exception;
if (ex == null)
{ {
// either on the async completion step, or the task is guarded TrySetResult(val);
// againsts thread-stealing; complete it directly
// (note: RunContinuationsAsynchronously is only usable from NET46)
UnwrapAndRecycle(this, true, out T val, out Exception ex);
if (ex == null)
{
tcs.TrySetResult(val);
}
else
{
if (ex is TaskCanceledException) tcs.TrySetCanceled();
else tcs.TrySetException(ex);
// mark it as observed
GC.KeepAlive(tcs.Task.Exception);
GC.SuppressFinalize(tcs.Task);
}
return true;
} }
else else
{ {
// could be thread-stealing continuations; push to async to preserve the reader thread if (ex is TaskCanceledException) TrySetCanceled();
return false; else TrySetException(ex);
// mark any exception as observed
var task = Task;
GC.KeepAlive(task.Exception);
GC.SuppressFinalize(task);
} }
return true;
} }
else else
{ {
lock (this) // could be thread-stealing continuations; push to async to preserve the reader thread
{ // tell the waiting thread that we're done return false;
Monitor.PulseAll(this);
}
ConnectionMultiplexer.TraceWithoutContext("Pulsed", "Result");
return true;
} }
} }
private void Reset(object stateOrCompletionSource) public static IResultBox<T> Create(out TaskCompletionSource<T> source, object asyncState, TaskCreationOptions creationOptions = TaskCreationOptions.None)
{ {
var newCount = Interlocked.Increment(ref _usageCount); // it might look a little odd to return the same object as two different things,
if (newCount != 1) throw new InvalidOperationException($"Result box count error: is {newCount} in Reset (should be 1)"); // but that's because it is serving two purposes, and I want to make it clear
value = default(T); // how it is being used in those 2 different ways; also, the *fact* that they
_exception = null; // are the same underlying object is an implementation detail that the rest of
// the code doesn't need to know about
this.stateOrCompletionSource = stateOrCompletionSource; var obj = new TaskResultBox<T>(asyncState, creationOptions);
source = obj;
return obj;
} }
} }
} }
...@@ -2032,7 +2032,7 @@ internal abstract class ResultProcessor<T> : ResultProcessor ...@@ -2032,7 +2032,7 @@ internal abstract class ResultProcessor<T> : ResultProcessor
protected void SetResult(Message message, T value) protected void SetResult(Message message, T value)
{ {
if (message == null) return; if (message == null) return;
var box = message.ResultBox as ResultBox<T>; var box = message.ResultBox as IResultBox<T>;
message.SetResponseReceived(); message.SetResponseReceived();
box?.SetResult(value); box?.SetResult(value);
......
...@@ -574,8 +574,7 @@ private static async Task<T> WriteDirectAsync_Awaited<T>(ServerEndPoint @this, M ...@@ -574,8 +574,7 @@ private static async Task<T> WriteDirectAsync_Awaited<T>(ServerEndPoint @this, M
internal Task<T> WriteDirectAsync<T>(Message message, ResultProcessor<T> processor, object asyncState = null, PhysicalBridge bridge = null) internal Task<T> WriteDirectAsync<T>(Message message, ResultProcessor<T> processor, object asyncState = null, PhysicalBridge bridge = null)
{ {
var tcs = TaskSource.Create<T>(asyncState); var source = TaskResultBox<T>.Create(out var tcs, asyncState);
var source = ResultBox<T>.Get(tcs);
message.SetSource(processor, source); message.SetSource(processor, source);
if (bridge == null) bridge = GetBridge(message.Command); if (bridge == null) bridge = GetBridge(message.Command);
......
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
</PropertyGroup> </PropertyGroup>
<ItemGroup> <ItemGroup>
<PackageReference Include="Pipelines.Sockets.Unofficial" Version="1.0.18" /> <PackageReference Include="Pipelines.Sockets.Unofficial" Version="1.0.20" />
<PackageReference Include="System.Diagnostics.PerformanceCounter" Version="4.5.0" /> <PackageReference Include="System.Diagnostics.PerformanceCounter" Version="4.5.0" />
<PackageReference Include="System.IO.Pipelines" Version="4.5.1" /> <PackageReference Include="System.IO.Pipelines" Version="4.5.1" />
<PackageReference Include="System.Threading.Channels" Version="4.5.0" /> <PackageReference Include="System.Threading.Channels" Version="4.5.0" />
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment