Unverified Commit ec5ba309 authored by Marc Gravell's avatar Marc Gravell Committed by GitHub

Async write path (#1056)

* start work on a true async write path (no sync flush); at the moment only existing async code uses this - once stable, we can move more code to the full async path

* finish prep work before tackling subscriber lock (note: still lots of handshake options to asyncify)

* implement a background queue that represents the (necessarily ordered) subscription operations, rather than executing them synchronously (sync over async, locks, etc)
parent 89b8f64b
...@@ -164,7 +164,7 @@ private async Task OnMessageSyncImpl() ...@@ -164,7 +164,7 @@ private async Task OnMessageSyncImpl()
while (!Completion.IsCompleted) while (!Completion.IsCompleted)
{ {
ChannelMessage next; ChannelMessage next;
try { if (!TryRead(out next)) next = await ReadAsync().ConfigureAwait(false); } try { if (!TryRead(out next)) next = await ReadAsync().ForAwait(); }
catch (ChannelClosedException) { break; } // expected catch (ChannelClosedException) { break; } // expected
catch (Exception ex) catch (Exception ex)
{ {
...@@ -195,7 +195,7 @@ private async Task OnMessageAsyncImpl() ...@@ -195,7 +195,7 @@ private async Task OnMessageAsyncImpl()
while (!Completion.IsCompleted) while (!Completion.IsCompleted)
{ {
ChannelMessage next; ChannelMessage next;
try { if (!TryRead(out next)) next = await ReadAsync().ConfigureAwait(false); } try { if (!TryRead(out next)) next = await ReadAsync().ForAwait(); }
catch (ChannelClosedException) { break; } // expected catch (ChannelClosedException) { break; } // expected
catch (Exception ex) catch (Exception ex)
{ {
...@@ -206,7 +206,7 @@ private async Task OnMessageAsyncImpl() ...@@ -206,7 +206,7 @@ private async Task OnMessageAsyncImpl()
try try
{ {
var task = handler(next); var task = handler(next);
if (task != null && task.Status != TaskStatus.RanToCompletion) await task.ConfigureAwait(false); if (task != null && task.Status != TaskStatus.RanToCompletion) await task.ForAwait();
} }
catch { } // matches MessageCompletable catch { } // matches MessageCompletable
} }
...@@ -229,7 +229,7 @@ internal async Task UnsubscribeAsyncImpl(Exception error = null, CommandFlags fl ...@@ -229,7 +229,7 @@ internal async Task UnsubscribeAsyncImpl(Exception error = null, CommandFlags fl
_parent = null; _parent = null;
if (parent != null) if (parent != null)
{ {
await parent.UnsubscribeAsync(Channel, HandleMessage, flags).ConfigureAwait(false); await parent.UnsubscribeAsync(Channel, HandleMessage, flags).ForAwait();
} }
_queue.Writer.TryComplete(error); _queue.Writer.TryComplete(error);
} }
......
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
#pragma warning disable RCS1231
namespace StackExchange.Redis namespace StackExchange.Redis
{ {
/// <summary> /// <summary>
...@@ -303,6 +305,8 @@ public static Condition StringNotEqual(RedisKey key, RedisValue value) ...@@ -303,6 +305,8 @@ public static Condition StringNotEqual(RedisKey key, RedisValue value)
/// <param name="count">The number of members which sorted set must not have.</param> /// <param name="count">The number of members which sorted set must not have.</param>
public static Condition SortedSetScoreNotExists(RedisKey key, RedisValue score, RedisValue count) => new SortedSetScoreCondition(key, score, false, count); public static Condition SortedSetScoreNotExists(RedisKey key, RedisValue score, RedisValue count) => new SortedSetScoreCondition(key, score, false, count);
#pragma warning restore RCS1231
internal abstract void CheckCommands(CommandMap commandMap); internal abstract void CheckCommands(CommandMap commandMap);
internal abstract IEnumerable<Message> CreateMessages(int db, ResultBox resultBox); internal abstract IEnumerable<Message> CreateMessages(int db, ResultBox resultBox);
...@@ -314,12 +318,14 @@ internal sealed class ConditionProcessor : ResultProcessor<bool> ...@@ -314,12 +318,14 @@ internal sealed class ConditionProcessor : ResultProcessor<bool>
{ {
public static readonly ConditionProcessor Default = new ConditionProcessor(); public static readonly ConditionProcessor Default = new ConditionProcessor();
public static Message CreateMessage(Condition condition, int db, CommandFlags flags, RedisCommand command, RedisKey key, RedisValue value = default(RedisValue)) #pragma warning disable RCS1231 // Make parameter ref read-only.
public static Message CreateMessage(Condition condition, int db, CommandFlags flags, RedisCommand command, in RedisKey key, RedisValue value = default(RedisValue))
#pragma warning restore RCS1231 // Make parameter ref read-only.
{ {
return new ConditionMessage(condition, db, flags, command, key, value); return new ConditionMessage(condition, db, flags, command, key, value);
} }
public static Message CreateMessage(Condition condition, int db, CommandFlags flags, RedisCommand command, RedisKey key, RedisValue value, RedisValue value1) public static Message CreateMessage(Condition condition, int db, CommandFlags flags, RedisCommand command, in RedisKey key, in RedisValue value, in RedisValue value1)
{ {
return new ConditionMessage(condition, db, flags, command, key, value, value1); return new ConditionMessage(condition, db, flags, command, key, value, value1);
} }
...@@ -343,14 +349,14 @@ private class ConditionMessage : Message.CommandKeyBase ...@@ -343,14 +349,14 @@ private class ConditionMessage : Message.CommandKeyBase
private readonly RedisValue value; private readonly RedisValue value;
private readonly RedisValue value1; private readonly RedisValue value1;
public ConditionMessage(Condition condition, int db, CommandFlags flags, RedisCommand command, RedisKey key, RedisValue value) public ConditionMessage(Condition condition, int db, CommandFlags flags, RedisCommand command, in RedisKey key, in RedisValue value)
: base(db, flags, command, key) : base(db, flags, command, key)
{ {
Condition = condition; Condition = condition;
this.value = value; // note no assert here this.value = value; // note no assert here
} }
public ConditionMessage(Condition condition, int db, CommandFlags flags, RedisCommand command, RedisKey key, RedisValue value, RedisValue value1) public ConditionMessage(Condition condition, int db, CommandFlags flags, RedisCommand command, in RedisKey key, in RedisValue value, in RedisValue value1)
: this(condition, db, flags, command, key, value) : this(condition, db, flags, command, key, value)
{ {
this.value1 = value1; // note no assert here this.value1 = value1; // note no assert here
...@@ -391,7 +397,7 @@ internal override Condition MapKeys(Func<RedisKey, RedisKey> map) ...@@ -391,7 +397,7 @@ internal override Condition MapKeys(Func<RedisKey, RedisKey> map)
return new ExistsCondition(map(key), type, expectedValue, expectedResult); return new ExistsCondition(map(key), type, expectedValue, expectedResult);
} }
public ExistsCondition(RedisKey key, RedisType type, RedisValue expectedValue, bool expectedResult) public ExistsCondition(in RedisKey key, RedisType type, in RedisValue expectedValue, bool expectedResult)
{ {
if (key.IsNull) throw new ArgumentException("key"); if (key.IsNull) throw new ArgumentException("key");
this.key = key; this.key = key;
...@@ -481,7 +487,7 @@ internal override Condition MapKeys(Func<RedisKey, RedisKey> map) ...@@ -481,7 +487,7 @@ internal override Condition MapKeys(Func<RedisKey, RedisKey> map)
private readonly RedisType type; private readonly RedisType type;
private readonly RedisCommand cmd; private readonly RedisCommand cmd;
public EqualsCondition(RedisKey key, RedisType type, RedisValue memberName, bool expectedEqual, RedisValue expectedValue) public EqualsCondition(in RedisKey key, RedisType type, in RedisValue memberName, bool expectedEqual, in RedisValue expectedValue)
{ {
if (key.IsNull) throw new ArgumentException("key"); if (key.IsNull) throw new ArgumentException("key");
this.key = key; this.key = key;
...@@ -575,7 +581,7 @@ internal override Condition MapKeys(Func<RedisKey, RedisKey> map) ...@@ -575,7 +581,7 @@ internal override Condition MapKeys(Func<RedisKey, RedisKey> map)
private readonly long index; private readonly long index;
private readonly RedisValue? expectedValue; private readonly RedisValue? expectedValue;
private readonly RedisKey key; private readonly RedisKey key;
public ListCondition(RedisKey key, long index, bool expectedResult, RedisValue? expectedValue) public ListCondition(in RedisKey key, long index, bool expectedResult, in RedisValue? expectedValue)
{ {
if (key.IsNull) throw new ArgumentException(nameof(key)); if (key.IsNull) throw new ArgumentException(nameof(key));
this.key = key; this.key = key;
...@@ -645,7 +651,7 @@ internal override Condition MapKeys(Func<RedisKey, RedisKey> map) ...@@ -645,7 +651,7 @@ internal override Condition MapKeys(Func<RedisKey, RedisKey> map)
private readonly RedisType type; private readonly RedisType type;
private readonly RedisCommand cmd; private readonly RedisCommand cmd;
public LengthCondition(RedisKey key, RedisType type, int compareToResult, long expectedLength) public LengthCondition(in RedisKey key, RedisType type, int compareToResult, long expectedLength)
{ {
if (key.IsNull) throw new ArgumentException(nameof(key)); if (key.IsNull) throw new ArgumentException(nameof(key));
this.key = key; this.key = key;
...@@ -737,7 +743,7 @@ internal override Condition MapKeys(Func<RedisKey, RedisKey> map) ...@@ -737,7 +743,7 @@ internal override Condition MapKeys(Func<RedisKey, RedisKey> map)
private readonly RedisValue sortedSetScore, expectedValue; private readonly RedisValue sortedSetScore, expectedValue;
private readonly RedisKey key; private readonly RedisKey key;
public SortedSetScoreCondition(RedisKey key, RedisValue sortedSetScore, bool expectedEqual, RedisValue expectedValue) public SortedSetScoreCondition(in RedisKey key, in RedisValue sortedSetScore, bool expectedEqual, in RedisValue expectedValue)
{ {
if (key.IsNull) if (key.IsNull)
{ {
......
...@@ -291,10 +291,13 @@ public int ConnectTimeout ...@@ -291,10 +291,13 @@ public int ConnectTimeout
/// </summary> /// </summary>
public bool HighPrioritySocketThreads { get { return highPrioritySocketThreads ?? true; } set { highPrioritySocketThreads = value; } } public bool HighPrioritySocketThreads { get { return highPrioritySocketThreads ?? true; } set { highPrioritySocketThreads = value; } }
// Use coalesce expression.
/// <summary> /// <summary>
/// Specifies the time in seconds at which connections should be pinged to ensure validity /// Specifies the time in seconds at which connections should be pinged to ensure validity
/// </summary> /// </summary>
#pragma warning disable RCS1128
public int KeepAlive { get { return keepAlive.GetValueOrDefault(-1); } set { keepAlive = value; } } public int KeepAlive { get { return keepAlive.GetValueOrDefault(-1); } set { keepAlive = value; } }
#pragma warning restore RCS1128 // Use coalesce expression.
/// <summary> /// <summary>
/// The password to use to authenticate with the server. /// The password to use to authenticate with the server.
...@@ -363,7 +366,9 @@ public bool PreserveAsyncOrder ...@@ -363,7 +366,9 @@ public bool PreserveAsyncOrder
/// <summary> /// <summary>
/// Specifies the time in milliseconds that the system should allow for synchronous operations (defaults to 1 second) /// Specifies the time in milliseconds that the system should allow for synchronous operations (defaults to 1 second)
/// </summary> /// </summary>
#pragma warning disable RCS1128
public int SyncTimeout { get { return syncTimeout.GetValueOrDefault(5000); } set { syncTimeout = value; } } public int SyncTimeout { get { return syncTimeout.GetValueOrDefault(5000); } set { syncTimeout = value; } }
#pragma warning restore RCS1128
/// <summary> /// <summary>
/// Tie-breaker used to choose between masters (must match the endpoint exactly) /// Tie-breaker used to choose between masters (must match the endpoint exactly)
...@@ -383,7 +388,9 @@ public bool PreserveAsyncOrder ...@@ -383,7 +388,9 @@ public bool PreserveAsyncOrder
/// <summary> /// <summary>
/// Check configuration every n seconds (every minute by default) /// Check configuration every n seconds (every minute by default)
/// </summary> /// </summary>
#pragma warning disable RCS1128
public int ConfigCheckSeconds { get { return configCheckSeconds.GetValueOrDefault(60); } set { configCheckSeconds = value; } } public int ConfigCheckSeconds { get { return configCheckSeconds.GetValueOrDefault(60); } set { configCheckSeconds = value; } }
#pragma warning restore RCS1128
/// <summary> /// <summary>
/// Parse the configuration from a comma-delimited configuration string /// Parse the configuration from a comma-delimited configuration string
......
...@@ -71,12 +71,12 @@ private async Task CloneAsync(string path, PipeReader from, PipeWriter to) ...@@ -71,12 +71,12 @@ private async Task CloneAsync(string path, PipeReader from, PipeWriter to)
arr = new ArraySegment<byte>(tmp, 0, segment.Length); arr = new ArraySegment<byte>(tmp, 0, segment.Length);
leased = true; leased = true;
} }
await file.WriteAsync(arr.Array, arr.Offset, arr.Count); await file.WriteAsync(arr.Array, arr.Offset, arr.Count).ForAwait();
await file.FlushAsync(); await file.FlushAsync().ForAwait();
if (leased) ArrayPool<byte>.Shared.Return(arr.Array); if (leased) ArrayPool<byte>.Shared.Return(arr.Array);
// and flush it upstream // and flush it upstream
await to.WriteAsync(segment); await to.WriteAsync(segment).ForAwait();
} }
} }
from.AdvanceTo(buffer.End); from.AdvanceTo(buffer.End);
......
This diff is collapsed.
...@@ -232,8 +232,6 @@ private enum ReadMode : byte ...@@ -232,8 +232,6 @@ private enum ReadMode : byte
public bool TransactionActive { get; internal set; } public bool TransactionActive { get; internal set; }
partial void ShouldIgnoreConnect(ref bool ignore);
[System.Diagnostics.CodeAnalysis.SuppressMessage("Microsoft.Usage", "CA2202:Do not dispose objects multiple times")] [System.Diagnostics.CodeAnalysis.SuppressMessage("Microsoft.Usage", "CA2202:Do not dispose objects multiple times")]
internal void Shutdown() internal void Shutdown()
{ {
...@@ -275,7 +273,7 @@ public void Dispose() ...@@ -275,7 +273,7 @@ public void Dispose()
private async Task AwaitedFlush(ValueTask<FlushResult> flush) private async Task AwaitedFlush(ValueTask<FlushResult> flush)
{ {
await flush; await flush.ForAwait();
_writeStatus = WriteStatus.Flushed; _writeStatus = WriteStatus.Flushed;
UpdateLastWriteTime(); UpdateLastWriteTime();
} }
...@@ -839,23 +837,50 @@ internal static int WriteRaw(Span<byte> span, long value, bool withLengthPrefix ...@@ -839,23 +837,50 @@ internal static int WriteRaw(Span<byte> span, long value, bool withLengthPrefix
return WriteCrlf(span, offset); return WriteCrlf(span, offset);
} }
internal WriteResult FlushSync(bool throwOnFailure = false) private static async ValueTask<WriteResult> FlushAsync_Awaited(PhysicalConnection connection, ValueTask<FlushResult> flush, bool throwOnFailure)
{
try
{
await flush.ForAwait();
connection._writeStatus = WriteStatus.Flushed;
connection.UpdateLastWriteTime();
return WriteResult.Success;
}
catch (ConnectionResetException ex) when (!throwOnFailure)
{
connection.RecordConnectionFailed(ConnectionFailureType.SocketClosed, ex);
return WriteResult.WriteFailure;
}
}
[Obsolete("this is an anti-pattern; work to reduce reliance on this is in progress")]
internal WriteResult FlushSync(bool throwOnFailure, int millisecondsTimeout)
{
var flush = FlushAsync(throwOnFailure);
if (!flush.IsCompletedSuccessfully)
{
// here lies the evil
if (!flush.AsTask().Wait(millisecondsTimeout)) throw new TimeoutException("timeout while synchronously flushing");
}
return flush.Result;
}
internal ValueTask<WriteResult> FlushAsync(bool throwOnFailure)
{ {
var tmp = _ioPipe?.Output; var tmp = _ioPipe?.Output;
if (tmp == null) return WriteResult.NoConnectionAvailable; if (tmp == null) return new ValueTask<WriteResult>(WriteResult.NoConnectionAvailable);
try try
{ {
_writeStatus = WriteStatus.Flushing; _writeStatus = WriteStatus.Flushing;
var flush = tmp.FlushAsync(); var flush = tmp.FlushAsync();
if (!flush.IsCompletedSuccessfully) flush.AsTask().Wait(); if (!flush.IsCompletedSuccessfully) return FlushAsync_Awaited(this, flush, throwOnFailure);
_writeStatus = WriteStatus.Flushed; _writeStatus = WriteStatus.Flushed;
UpdateLastWriteTime(); UpdateLastWriteTime();
return WriteResult.Success; return new ValueTask<WriteResult>(WriteResult.Success);
} }
catch (ConnectionResetException ex) when (!throwOnFailure) catch (ConnectionResetException ex) when (!throwOnFailure)
{ {
RecordConnectionFailed(ConnectionFailureType.SocketClosed, ex); RecordConnectionFailed(ConnectionFailureType.SocketClosed, ex);
return WriteResult.WriteFailure; return new ValueTask<WriteResult>(WriteResult.WriteFailure);
} }
} }
...@@ -874,7 +899,9 @@ private static void WriteUnifiedBlob(PipeWriter writer, byte[] value) ...@@ -874,7 +899,9 @@ private static void WriteUnifiedBlob(PipeWriter writer, byte[] value)
} }
} }
#pragma warning disable RCS1231 // Make parameter ref read-only.
private static void WriteUnifiedSpan(PipeWriter writer, ReadOnlySpan<byte> value) private static void WriteUnifiedSpan(PipeWriter writer, ReadOnlySpan<byte> value)
#pragma warning restore RCS1231 // Make parameter ref read-only.
{ {
// ${len}\r\n = 3 + MaxInt32TextLen // ${len}\r\n = 3 + MaxInt32TextLen
// {value}\r\n = 2 + value.Length // {value}\r\n = 2 + value.Length
...@@ -949,8 +976,8 @@ internal void WriteSha1AsHex(byte[] value) ...@@ -949,8 +976,8 @@ internal void WriteSha1AsHex(byte[] value)
for (int i = 0; i < value.Length; i++) for (int i = 0; i < value.Length; i++)
{ {
var b = value[i]; var b = value[i];
span[offset++] = ToHexNibble(value[i] >> 4); span[offset++] = ToHexNibble(b >> 4);
span[offset++] = ToHexNibble(value[i] & 15); span[offset++] = ToHexNibble(b & 15);
} }
span[offset++] = (byte)'\r'; span[offset++] = (byte)'\r';
span[offset++] = (byte)'\n'; span[offset++] = (byte)'\n';
...@@ -1556,7 +1583,9 @@ private static RawResult ParseInlineProtocol(in RawResult line) ...@@ -1556,7 +1583,9 @@ private static RawResult ParseInlineProtocol(in RawResult line)
if (!line.HasValue) return RawResult.Nil; // incomplete line if (!line.HasValue) return RawResult.Nil; // incomplete line
int count = 0; int count = 0;
foreach (var token in line.GetInlineTokenizer()) count++; #pragma warning disable IDE0059
foreach (var _ in line.GetInlineTokenizer()) count++;
#pragma warning restore IDE0059
var oversized = ArrayPool<RawResult>.Shared.Rent(count); var oversized = ArrayPool<RawResult>.Shared.Rent(count);
count = 0; count = 0;
foreach (var token in line.GetInlineTokenizer()) foreach (var token in line.GetInlineTokenizer())
......
...@@ -54,7 +54,7 @@ internal virtual T ExecuteSync<T>(Message message, ResultProcessor<T> processor, ...@@ -54,7 +54,7 @@ internal virtual T ExecuteSync<T>(Message message, ResultProcessor<T> processor,
return multiplexer.ExecuteSyncImpl<T>(message, processor, server); return multiplexer.ExecuteSyncImpl<T>(message, processor, server);
} }
internal virtual RedisFeatures GetFeatures(RedisKey key, CommandFlags flags, out ServerEndPoint server) internal virtual RedisFeatures GetFeatures(in RedisKey key, CommandFlags flags, out ServerEndPoint server)
{ {
server = multiplexer.SelectServer(RedisCommand.PING, flags, key); server = multiplexer.SelectServer(RedisCommand.PING, flags, key);
var version = server == null ? multiplexer.RawConfig.DefaultVersion : server.Version; var version = server == null ? multiplexer.RawConfig.DefaultVersion : server.Version;
...@@ -116,7 +116,7 @@ private ResultProcessor.TimingProcessor.TimerMessage GetTimerMessage(CommandFlag ...@@ -116,7 +116,7 @@ private ResultProcessor.TimingProcessor.TimerMessage GetTimerMessage(CommandFlag
internal static class CursorUtils internal static class CursorUtils
{ {
internal const int Origin = 0, DefaultPageSize = 10; internal const int Origin = 0, DefaultPageSize = 10;
internal static bool IsNil(RedisValue pattern) internal static bool IsNil(in RedisValue pattern)
{ {
if (pattern.IsNullOrEmpty) return true; if (pattern.IsNullOrEmpty) return true;
if (pattern.IsInteger) return false; if (pattern.IsInteger) return false;
...@@ -231,7 +231,7 @@ private enum State : byte ...@@ -231,7 +231,7 @@ private enum State : byte
Disposed, Disposed,
} }
private void ProcessReply(ScanResult result) private void ProcessReply(in ScanResult result)
{ {
currentCursor = nextCursor; currentCursor = nextCursor;
nextCursor = result.Cursor; nextCursor = result.Cursor;
......
...@@ -7,6 +7,8 @@ ...@@ -7,6 +7,8 @@
using System.Text; using System.Text;
using System.Threading.Tasks; using System.Threading.Tasks;
#pragma warning disable RCS1231 // Make parameter ref read-only.
namespace StackExchange.Redis namespace StackExchange.Redis
{ {
internal sealed class RedisServer : RedisBase, IServer internal sealed class RedisServer : RedisBase, IServer
...@@ -583,7 +585,7 @@ internal override T ExecuteSync<T>(Message message, ResultProcessor<T> processor ...@@ -583,7 +585,7 @@ internal override T ExecuteSync<T>(Message message, ResultProcessor<T> processor
return base.ExecuteSync<T>(message, processor, server); return base.ExecuteSync<T>(message, processor, server);
} }
internal override RedisFeatures GetFeatures(RedisKey key, CommandFlags flags, out ServerEndPoint server) internal override RedisFeatures GetFeatures(in RedisKey key, CommandFlags flags, out ServerEndPoint server)
{ {
server = this.server; server = this.server;
return new RedisFeatures(server.Version); return new RedisFeatures(server.Version);
...@@ -607,7 +609,9 @@ public void SlaveOf(EndPoint master, CommandFlags flags = CommandFlags.None) ...@@ -607,7 +609,9 @@ public void SlaveOf(EndPoint master, CommandFlags flags = CommandFlags.None)
{ {
var del = Message.Create(0, CommandFlags.FireAndForget | CommandFlags.NoRedirect, RedisCommand.DEL, (RedisKey)configuration.TieBreaker); var del = Message.Create(0, CommandFlags.FireAndForget | CommandFlags.NoRedirect, RedisCommand.DEL, (RedisKey)configuration.TieBreaker);
del.SetInternalCall(); del.SetInternalCall();
server.WriteDirectFireAndForget(del, ResultProcessor.Boolean); #pragma warning disable CS0618
server.WriteDirectFireAndForgetSync(del, ResultProcessor.Boolean);
#pragma warning restore CS0618
} }
ExecuteSync(slaveofMsg, ResultProcessor.DemandOK); ExecuteSync(slaveofMsg, ResultProcessor.DemandOK);
...@@ -617,7 +621,9 @@ public void SlaveOf(EndPoint master, CommandFlags flags = CommandFlags.None) ...@@ -617,7 +621,9 @@ public void SlaveOf(EndPoint master, CommandFlags flags = CommandFlags.None)
{ {
var pub = Message.Create(-1, CommandFlags.FireAndForget | CommandFlags.NoRedirect, RedisCommand.PUBLISH, (RedisValue)channel, RedisLiterals.Wildcard); var pub = Message.Create(-1, CommandFlags.FireAndForget | CommandFlags.NoRedirect, RedisCommand.PUBLISH, (RedisValue)channel, RedisLiterals.Wildcard);
pub.SetInternalCall(); pub.SetInternalCall();
server.WriteDirectFireAndForget(pub, ResultProcessor.Int64); #pragma warning disable CS0618
server.WriteDirectFireAndForgetSync(pub, ResultProcessor.Int64);
#pragma warning restore CS0618
} }
} }
......
...@@ -127,7 +127,7 @@ internal void ResendSubscriptions(ServerEndPoint server) ...@@ -127,7 +127,7 @@ internal void ResendSubscriptions(ServerEndPoint server)
} }
} }
internal bool SubscriberConnected(RedisChannel channel = default(RedisChannel)) internal bool SubscriberConnected(in RedisChannel channel = default(RedisChannel))
{ {
var server = GetSubscribedServer(channel); var server = GetSubscribedServer(channel);
if (server != null) return server.IsConnected; if (server != null) return server.IsConnected;
...@@ -149,7 +149,7 @@ internal long ValidateSubscriptions() ...@@ -149,7 +149,7 @@ internal long ValidateSubscriptions()
} }
} }
private sealed class Subscription internal sealed class Subscription
{ {
private Action<RedisChannel, RedisValue> _asyncHandler, _syncHandler; private Action<RedisChannel, RedisValue> _asyncHandler, _syncHandler;
private ServerEndPoint owner; private ServerEndPoint owner;
...@@ -195,25 +195,79 @@ public bool Remove(bool asAsync, Action<RedisChannel, RedisValue> value) ...@@ -195,25 +195,79 @@ public bool Remove(bool asAsync, Action<RedisChannel, RedisValue> value)
public Task SubscribeToServer(ConnectionMultiplexer multiplexer, in RedisChannel channel, CommandFlags flags, object asyncState, bool internalCall) public Task SubscribeToServer(ConnectionMultiplexer multiplexer, in RedisChannel channel, CommandFlags flags, object asyncState, bool internalCall)
{ {
var cmd = channel.IsPatternBased ? RedisCommand.PSUBSCRIBE : RedisCommand.SUBSCRIBE; var selected = multiplexer.SelectServer(RedisCommand.SUBSCRIBE, flags, default(RedisKey));
var selected = multiplexer.SelectServer(cmd, flags, default(RedisKey)); var bridge = selected?.GetBridge(ConnectionType.Subscription, true);
if (bridge == null) return null;
if (selected == null || Interlocked.CompareExchange(ref owner, selected, null) != null) return null; // note: check we can create the message validly *before* we swap the owner over (Interlocked)
var state = PendingSubscriptionState.Create(channel, this, flags, true, internalCall, asyncState, selected.IsSlave);
var msg = Message.Create(-1, flags, cmd, channel); if (Interlocked.CompareExchange(ref owner, selected, null) != null) return null;
if (internalCall) msg.SetInternalCall(); try
return selected.WriteDirectAsync(msg, ResultProcessor.TrackSubscriptions, asyncState); {
if (!bridge.TryEnqueueBackgroundSubscriptionWrite(state))
{
state.Abort();
return null;
}
return state.Task;
}
catch
{
// clear the owner if it is still us
Interlocked.CompareExchange(ref owner, null, selected);
throw;
}
} }
public Task UnsubscribeFromServer(in RedisChannel channel, CommandFlags flags, object asyncState, bool internalCall) public Task UnsubscribeFromServer(in RedisChannel channel, CommandFlags flags, object asyncState, bool internalCall)
{ {
var oldOwner = Interlocked.Exchange(ref owner, null); var oldOwner = Interlocked.Exchange(ref owner, null);
if (oldOwner == null) return null; var bridge = oldOwner?.GetBridge(ConnectionType.Subscription, false);
if (bridge == null) return null;
var state = PendingSubscriptionState.Create(channel, this, flags, false, internalCall, asyncState, oldOwner.IsSlave);
var cmd = channel.IsPatternBased ? RedisCommand.PUNSUBSCRIBE : RedisCommand.UNSUBSCRIBE; if (!bridge.TryEnqueueBackgroundSubscriptionWrite(state))
{
state.Abort();
return null;
}
return state.Task;
}
internal readonly struct PendingSubscriptionState
{
public override string ToString() => Message.ToString();
public Subscription Subscription { get; }
public Message Message { get; }
public bool IsSlave { get; }
public Task Task => _taskSource.Task;
private readonly TaskCompletionSource<bool> _taskSource;
public static PendingSubscriptionState Create(RedisChannel channel, Subscription subscription, CommandFlags flags, bool subscribe, bool internalCall, object asyncState, bool isSlave)
=> new PendingSubscriptionState(asyncState, channel, subscription, flags, subscribe, internalCall, isSlave);
public void Abort() => _taskSource.TrySetCanceled();
public void Fail(Exception ex) => _taskSource.TrySetException(ex);
private PendingSubscriptionState(object asyncState, RedisChannel channel, Subscription subscription, CommandFlags flags, bool subscribe, bool internalCall, bool isSlave)
{
var cmd = subscribe
? (channel.IsPatternBased ? RedisCommand.PSUBSCRIBE : RedisCommand.SUBSCRIBE)
: (channel.IsPatternBased ? RedisCommand.PUNSUBSCRIBE : RedisCommand.UNSUBSCRIBE);
var msg = Message.Create(-1, flags, cmd, channel); var msg = Message.Create(-1, flags, cmd, channel);
if (internalCall) msg.SetInternalCall(); if (internalCall) msg.SetInternalCall();
return oldOwner.WriteDirectAsync(msg, ResultProcessor.TrackSubscriptions, asyncState); var taskSource = TaskSource.Create<bool>(asyncState);
var source = ResultBox<bool>.Get(taskSource);
msg.SetSource(ResultProcessor.TrackSubscriptions, source);
Subscription = subscription;
_taskSource = taskSource;
Message = msg;
IsSlave = isSlave;
}
} }
internal ServerEndPoint GetOwner() => Volatile.Read(ref owner); internal ServerEndPoint GetOwner() => Volatile.Read(ref owner);
...@@ -225,7 +279,9 @@ internal void Resubscribe(in RedisChannel channel, ServerEndPoint server) ...@@ -225,7 +279,9 @@ internal void Resubscribe(in RedisChannel channel, ServerEndPoint server)
var cmd = channel.IsPatternBased ? RedisCommand.PSUBSCRIBE : RedisCommand.SUBSCRIBE; var cmd = channel.IsPatternBased ? RedisCommand.PSUBSCRIBE : RedisCommand.SUBSCRIBE;
var msg = Message.Create(-1, CommandFlags.FireAndForget, cmd, channel); var msg = Message.Create(-1, CommandFlags.FireAndForget, cmd, channel);
msg.SetInternalCall(); msg.SetInternalCall();
server.WriteDirectFireAndForget(msg, ResultProcessor.TrackSubscriptions); #pragma warning disable CS0618
server.WriteDirectFireAndForgetSync(msg, ResultProcessor.TrackSubscriptions);
#pragma warning restore CS0618
} }
} }
......
...@@ -287,7 +287,9 @@ public IEnumerable<Message> GetMessages(PhysicalConnection connection) ...@@ -287,7 +287,9 @@ public IEnumerable<Message> GetMessages(PhysicalConnection connection)
sb.AppendLine("checking conditions in the *early* path"); sb.AppendLine("checking conditions in the *early* path");
// need to get those sent ASAP; if they are stuck in the buffers, we die // need to get those sent ASAP; if they are stuck in the buffers, we die
multiplexer.Trace("Flushing and waiting for precondition responses"); multiplexer.Trace("Flushing and waiting for precondition responses");
connection.FlushSync(true); // make sure they get sent, so we can check for QUEUED (and the pre-conditions if necessary) #pragma warning disable CS0618
connection.FlushSync(true, multiplexer.TimeoutMilliseconds); // make sure they get sent, so we can check for QUEUED (and the pre-conditions if necessary)
#pragma warning restore CS0618
if (Monitor.Wait(lastBox, multiplexer.TimeoutMilliseconds)) if (Monitor.Wait(lastBox, multiplexer.TimeoutMilliseconds))
{ {
...@@ -344,7 +346,9 @@ public IEnumerable<Message> GetMessages(PhysicalConnection connection) ...@@ -344,7 +346,9 @@ public IEnumerable<Message> GetMessages(PhysicalConnection connection)
sb.AppendLine("checking conditions in the *late* path"); sb.AppendLine("checking conditions in the *late* path");
multiplexer.Trace("Flushing and waiting for precondition+queued responses"); multiplexer.Trace("Flushing and waiting for precondition+queued responses");
connection.FlushSync(true); // make sure they get sent, so we can check for QUEUED (and the pre-conditions if necessary) #pragma warning disable CS0618
connection.FlushSync(true, multiplexer.TimeoutMilliseconds); // make sure they get sent, so we can check for QUEUED (and the pre-conditions if necessary)
#pragma warning restore CS0618
if (Monitor.Wait(lastBox, multiplexer.TimeoutMilliseconds)) if (Monitor.Wait(lastBox, multiplexer.TimeoutMilliseconds))
{ {
if (!AreAllConditionsSatisfied(multiplexer)) if (!AreAllConditionsSatisfied(multiplexer))
......
...@@ -488,8 +488,7 @@ protected override bool SetResultCore(PhysicalConnection connection, Message mes ...@@ -488,8 +488,7 @@ protected override bool SetResultCore(PhysicalConnection connection, Message mes
{ {
hash = ParseSHA1(asciiHash); // external caller wants the hex bytes, not the ascii bytes hash = ParseSHA1(asciiHash); // external caller wants the hex bytes, not the ascii bytes
} }
var sl = message as RedisDatabase.ScriptLoadMessage; if (message is RedisDatabase.ScriptLoadMessage sl)
if (sl != null)
{ {
connection.BridgeCouldBeNull?.ServerEndPoint?.AddScript(sl.Script, asciiHash); connection.BridgeCouldBeNull?.ServerEndPoint?.AddScript(sl.Script, asciiHash);
} }
...@@ -1413,7 +1412,7 @@ protected override bool SetResultCore(PhysicalConnection connection, Message mes ...@@ -1413,7 +1412,7 @@ protected override bool SetResultCore(PhysicalConnection connection, Message mes
return false; return false;
} }
StreamEntry[] entries = null; StreamEntry[] entries;
if (skipStreamName) if (skipStreamName)
{ {
......
This diff is collapsed.
...@@ -61,10 +61,10 @@ public ServerSelectionStrategy(ConnectionMultiplexer multiplexer) ...@@ -61,10 +61,10 @@ public ServerSelectionStrategy(ConnectionMultiplexer multiplexer)
/// Computes the hash-slot that would be used by the given key /// Computes the hash-slot that would be used by the given key
/// </summary> /// </summary>
/// <param name="key">The <see cref="RedisKey"/> to determine a slot ID for.</param> /// <param name="key">The <see cref="RedisKey"/> to determine a slot ID for.</param>
public int HashSlot(RedisKey key) public int HashSlot(in RedisKey key)
=> ServerType == ServerType.Standalone ? NoSlot : GetClusterSlot(key); => ServerType == ServerType.Standalone ? NoSlot : GetClusterSlot(key);
private static unsafe int GetClusterSlot(RedisKey key) private static unsafe int GetClusterSlot(in RedisKey key)
{ {
//HASH_SLOT = CRC16(key) mod 16384 //HASH_SLOT = CRC16(key) mod 16384
if (key.IsNull) return NoSlot; if (key.IsNull) return NoSlot;
...@@ -107,7 +107,7 @@ public ServerEndPoint Select(Message message) ...@@ -107,7 +107,7 @@ public ServerEndPoint Select(Message message)
return Select(slot, message.Command, message.Flags); return Select(slot, message.Command, message.Flags);
} }
public ServerEndPoint Select(RedisCommand command, RedisKey key, CommandFlags flags) public ServerEndPoint Select(RedisCommand command, in RedisKey key, CommandFlags flags)
{ {
int slot = ServerType == ServerType.Cluster ? HashSlot(key) : NoSlot; int slot = ServerType == ServerType.Cluster ? HashSlot(key) : NoSlot;
return Select(slot, command, flags); return Select(slot, command, flags);
...@@ -155,7 +155,9 @@ public bool TryResend(int hashSlot, Message message, EndPoint endpoint, bool isM ...@@ -155,7 +155,9 @@ public bool TryResend(int hashSlot, Message message, EndPoint endpoint, bool isM
else else
{ {
message.PrepareToResend(resendVia, isMoved); message.PrepareToResend(resendVia, isMoved);
retry = resendVia.TryWrite(message) == WriteResult.Success; #pragma warning disable CS0618
retry = resendVia.TryWriteSync(message) == WriteResult.Success;
#pragma warning restore CS0618
} }
} }
...@@ -187,7 +189,7 @@ internal int CombineSlot(int oldSlot, int newSlot) ...@@ -187,7 +189,7 @@ internal int CombineSlot(int oldSlot, int newSlot)
return oldSlot == newSlot ? oldSlot : MultipleSlots; return oldSlot == newSlot ? oldSlot : MultipleSlots;
} }
internal int CombineSlot(int oldSlot, RedisKey key) internal int CombineSlot(int oldSlot, in RedisKey key)
{ {
if (oldSlot == MultipleSlots || key.IsNull) return oldSlot; if (oldSlot == MultipleSlots || key.IsNull) return oldSlot;
......
...@@ -26,6 +26,7 @@ public static Task<T> ObserveErrors<T>(this Task<T> task) ...@@ -26,6 +26,7 @@ public static Task<T> ObserveErrors<T>(this Task<T> task)
} }
public static ConfiguredTaskAwaitable ForAwait(this Task task) => task.ConfigureAwait(false); public static ConfiguredTaskAwaitable ForAwait(this Task task) => task.ConfigureAwait(false);
public static ConfiguredValueTaskAwaitable ForAwait(this ValueTask task) => task.ConfigureAwait(false);
public static ConfiguredTaskAwaitable<T> ForAwait<T>(this Task<T> task) => task.ConfigureAwait(false); public static ConfiguredTaskAwaitable<T> ForAwait<T>(this Task<T> task) => task.ConfigureAwait(false);
public static ConfiguredValueTaskAwaitable<T> ForAwait<T>(this ValueTask<T> task) => task.ConfigureAwait(false); public static ConfiguredValueTaskAwaitable<T> ForAwait<T>(this ValueTask<T> task) => task.ConfigureAwait(false);
...@@ -37,10 +38,10 @@ public static Task<T> ObserveErrors<T>(this Task<T> task) ...@@ -37,10 +38,10 @@ public static Task<T> ObserveErrors<T>(this Task<T> task)
public static async Task<bool> TimeoutAfter(this Task task, int timeoutMs) public static async Task<bool> TimeoutAfter(this Task task, int timeoutMs)
{ {
var cts = new CancellationTokenSource(); var cts = new CancellationTokenSource();
if (task == await Task.WhenAny(task, Task.Delay(timeoutMs, cts.Token)).ConfigureAwait(false)) if (task == await Task.WhenAny(task, Task.Delay(timeoutMs, cts.Token)).ForAwait())
{ {
cts.Cancel(); cts.Cancel();
await task.ConfigureAwait(false); await task.ForAwait();
return true; return true;
} }
else else
......
using System; using System;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;
using Xunit; using Xunit;
using Xunit.Abstractions; using Xunit.Abstractions;
...@@ -31,5 +34,59 @@ public void SubscriberCount() ...@@ -31,5 +34,59 @@ public void SubscriberCount()
Assert.Contains(channel, channels); Assert.Contains(channel, channels);
} }
} }
[Fact]
public async Task SubscriberCountAsync()
{
using (var conn = Create())
{
RedisChannel channel = Me() + Guid.NewGuid();
var server = conn.GetServer(conn.GetEndPoints()[0]);
var channels = await server.SubscriptionChannelsAsync(Me() + "*").WithTimeout(2000);
Assert.DoesNotContain(channel, channels);
long justWork = await server.SubscriptionPatternCountAsync().WithTimeout(2000);
var count = await server.SubscriptionSubscriberCountAsync(channel).WithTimeout(2000);
Assert.Equal(0, count);
await conn.GetSubscriber().SubscribeAsync(channel, delegate { }).WithTimeout(2000);
count = await server.SubscriptionSubscriberCountAsync(channel).WithTimeout(2000);
Assert.Equal(1, count);
channels = await server.SubscriptionChannelsAsync(Me() + "*").WithTimeout(2000);
Assert.Contains(channel, channels);
}
}
}
static class Util
{
public static async Task WithTimeout(this Task task, int timeoutMs,
[CallerMemberName] string caller = null, [CallerLineNumber] int line = 0)
{
var cts = new CancellationTokenSource();
if (task == await Task.WhenAny(task, Task.Delay(timeoutMs, cts.Token)).ForAwait())
{
cts.Cancel();
await task.ForAwait();
}
else
{
throw new TimeoutException($"timout from {caller} line {line}");
}
}
public static async Task<T> WithTimeout<T>(this Task<T> task, int timeoutMs,
[CallerMemberName] string caller = null, [CallerLineNumber] int line = 0)
{
var cts = new CancellationTokenSource();
if (task == await Task.WhenAny(task, Task.Delay(timeoutMs, cts.Token)).ForAwait())
{
cts.Cancel();
return await task.ForAwait();
}
else
{
throw new TimeoutException($"timout from {caller} line {line}");
}
}
} }
} }
...@@ -8,55 +8,16 @@ namespace TestConsole ...@@ -8,55 +8,16 @@ namespace TestConsole
{ {
internal static class Program internal static class Program
{ {
private const int taskCount = 10; public static async Task Main()
private const int totalRecords = 100000;
private static void Main()
{
#if SEV2
Pipelines.Sockets.Unofficial.SocketConnection.AssertDependencies();
Console.WriteLine("We loaded the things...");
// Console.ReadLine();
#endif
Stopwatch stopwatch = new Stopwatch();
stopwatch.Start();
var taskList = new List<Task>();
var connection = ConnectionMultiplexer.Connect("127.0.0.1");
for (int i = 0; i < taskCount; i++)
{
var i1 = i;
var task = new Task(() => Run(i1, connection));
task.Start();
taskList.Add(task);
}
Task.WaitAll(taskList.ToArray());
stopwatch.Stop();
Console.WriteLine($"Done. {stopwatch.ElapsedMilliseconds}");
Console.ReadLine();
}
private static void Run(int taskId, ConnectionMultiplexer connection)
{
Console.WriteLine($"{taskId} Started");
var database = connection.GetDatabase(0);
for (int i = 0; i < totalRecords; i++)
{ {
database.StringSet(i.ToString(), i.ToString()); using (var muxer = await ConnectionMultiplexer.ConnectAsync("127.0.0.1"))
}
Console.WriteLine($"{taskId} Insert completed");
for (int i = 0; i < totalRecords; i++)
{ {
var result = database.StringGet(i.ToString()); var db = muxer.GetDatabase();
var sub = muxer.GetSubscriber();
Console.WriteLine("subscribing");
ChannelMessageQueue queue = await sub.SubscribeAsync("yolo");
Console.WriteLine("subscribed");
} }
Console.WriteLine($"{taskId} Completed");
} }
} }
} }
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment