Commit 6af1811b authored by Marc Gravell's avatar Marc Gravell

all of the outbound changes

parent 5c31bb25
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
<AssemblyOriginatorKeyFile>../StackExchange.Redis.snk</AssemblyOriginatorKeyFile> <AssemblyOriginatorKeyFile>../StackExchange.Redis.snk</AssemblyOriginatorKeyFile>
<PackageId>$(AssemblyName)</PackageId> <PackageId>$(AssemblyName)</PackageId>
<Authors>Stack Exchange, Inc.; marc.gravell</Authors> <Authors>Stack Exchange, Inc.; marc.gravell</Authors>
<AllowUnsafeBlocks>true</AllowUnsafeBlocks> <AllowUnsafeBlocks>true</AllowUnsafeBlocks>
<PackageReleaseNotes>https://stackexchange.github.io/StackExchange.Redis/ReleaseNotes</PackageReleaseNotes> <PackageReleaseNotes>https://stackexchange.github.io/StackExchange.Redis/ReleaseNotes</PackageReleaseNotes>
<PackageProjectUrl>https://github.com/StackExchange/StackExchange.Redis/</PackageProjectUrl> <PackageProjectUrl>https://github.com/StackExchange/StackExchange.Redis/</PackageProjectUrl>
...@@ -21,7 +21,7 @@ ...@@ -21,7 +21,7 @@
<DefaultLanguage>en-US</DefaultLanguage> <DefaultLanguage>en-US</DefaultLanguage>
<IncludeSymbols>false</IncludeSymbols> <IncludeSymbols>false</IncludeSymbols>
<LibraryTargetFrameworks>net45;net46;netstandard2.0</LibraryTargetFrameworks> <LibraryTargetFrameworks>net46;netstandard2.0</LibraryTargetFrameworks>
<CoreFxVersion>4.5.0</CoreFxVersion> <CoreFxVersion>4.5.0</CoreFxVersion>
<xUnitVersion>2.4.0-beta.2.build3981</xUnitVersion> <xUnitVersion>2.4.0-beta.2.build3981</xUnitVersion>
</PropertyGroup> </PropertyGroup>
......
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
</ItemGroup> </ItemGroup>
<PropertyGroup Condition=" '$(TargetFramework)' == 'net45' or '$(TargetFramework)' == 'net46'"> <PropertyGroup Condition=" '$(TargetFramework)' == 'net45' or '$(TargetFramework)' == 'net46'">
<DefineConstants>$(DefineConstants);FEATURE_SOCKET_MODE_POLL;FEATURE_PERFCOUNTER;</DefineConstants> <DefineConstants>$(DefineConstants);FEATURE_PERFCOUNTER;</DefineConstants>
</PropertyGroup> </PropertyGroup>
<ItemGroup> <ItemGroup>
...@@ -27,9 +27,9 @@ ...@@ -27,9 +27,9 @@
<PackageReference Include="System.Reflection.Emit.ILGeneration" Version="4.3.0" /> <PackageReference Include="System.Reflection.Emit.ILGeneration" Version="4.3.0" />
<PackageReference Include="System.Reflection.Emit.Lightweight" Version="4.3.0" /> <PackageReference Include="System.Reflection.Emit.Lightweight" Version="4.3.0" />
<!--<PackageReference Include="System.Memory" Version="$(CoreFxVersion)" /> <PackageReference Include="System.Memory" Version="$(CoreFxVersion)" />
<PackageReference Include="System.Buffers" Version="$(CoreFxVersion)" /> <PackageReference Include="System.Buffers" Version="$(CoreFxVersion)" />
<PackageReference Include="System.IO.Pipelines" Version="$(CoreFxVersion)" /> <PackageReference Include="System.IO.Pipelines" Version="$(CoreFxVersion)" />
<PackageReference Include="Pipelines.Sockets.Unofficial" Version="0.2.0-alpha-001" />--> <PackageReference Include="Pipelines.Sockets.Unofficial" Version="0.2.0-alpha-004" />
</ItemGroup> </ItemGroup>
</Project> </Project>
\ No newline at end of file
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
using System.Runtime.CompilerServices; using System.Runtime.CompilerServices;
using System.Text; using System.Text;
using System.Threading; using System.Threading;
using System.Threading.Tasks;
namespace StackExchange.Redis namespace StackExchange.Redis
{ {
...@@ -267,21 +268,19 @@ internal void KeepAlive() ...@@ -267,21 +268,19 @@ internal void KeepAlive()
} }
} }
internal void OnConnected(PhysicalConnection connection, TextWriter log) internal async Task OnConnectedAsync(PhysicalConnection connection, TextWriter log)
{ {
Trace("OnConnected"); Trace("OnConnected");
if (physical == connection && !isDisposed && ChangeState(State.Connecting, State.ConnectedEstablishing)) if (physical == connection && !isDisposed && ChangeState(State.Connecting, State.ConnectedEstablishing))
{ {
ServerEndPoint.OnEstablishing(connection, log); await ServerEndPoint.OnEstablishingAsync(connection, log);
} }
else else
{ {
try try
{ {
connection.Dispose(); connection.Dispose();
} } catch { }
catch
{ }
} }
} }
...@@ -552,7 +551,10 @@ internal bool WriteMessageDirect(PhysicalConnection tmp, Message next) ...@@ -552,7 +551,10 @@ internal bool WriteMessageDirect(PhysicalConnection tmp, Message next)
} }
} }
internal WriteResult WriteQueue(int maxWork) [MethodImpl(MethodImplOptions.AggressiveInlining)]
static ValueTask<T> AsResult<T>(T value) => new ValueTask<T>(value);
internal async ValueTask<WriteResult> WriteQueueAsync(int maxWork)
{ {
bool weAreWriter = false; bool weAreWriter = false;
PhysicalConnection conn = null; PhysicalConnection conn = null;
...@@ -585,7 +587,7 @@ internal WriteResult WriteQueue(int maxWork) ...@@ -585,7 +587,7 @@ internal WriteResult WriteQueue(int maxWork)
Trace("Nothing to write; exiting"); Trace("Nothing to write; exiting");
if(count == 0) if(count == 0)
{ {
conn.Flush(); // only flush on an empty run await conn.FlushAsync(); // only flush on an empty run
return WriteResult.NothingToDo; return WriteResult.NothingToDo;
} }
return WriteResult.QueueEmptyAfterWrite; return WriteResult.QueueEmptyAfterWrite;
...@@ -604,7 +606,7 @@ internal WriteResult WriteQueue(int maxWork) ...@@ -604,7 +606,7 @@ internal WriteResult WriteQueue(int maxWork)
{ {
Trace("Work limit; exiting"); Trace("Work limit; exiting");
Trace(last != null, "Flushed up to: " + last); Trace(last != null, "Flushed up to: " + last);
conn.Flush(); await conn.FlushAsync();
break; break;
} }
} }
......
using System; using System;
using System.Buffers;
using System.Collections.Generic; using System.Collections.Generic;
using System.Diagnostics;
using System.IO; using System.IO;
using System.IO.Pipelines;
using System.Linq; using System.Linq;
using System.Net; using System.Net;
using System.Net.Security; using System.Net.Security;
...@@ -10,6 +13,7 @@ ...@@ -10,6 +13,7 @@
using System.Security.Cryptography.X509Certificates; using System.Security.Cryptography.X509Certificates;
using System.Text; using System.Text;
using System.Threading; using System.Threading;
using System.Threading.Tasks;
namespace StackExchange.Redis namespace StackExchange.Redis
{ {
...@@ -60,14 +64,10 @@ private static readonly Message ...@@ -60,14 +64,10 @@ private static readonly Message
private int failureReported; private int failureReported;
private byte[] ioBuffer = new byte[512];
private int ioBufferBytes = 0;
private int lastWriteTickCount, lastReadTickCount, lastBeatTickCount; private int lastWriteTickCount, lastReadTickCount, lastBeatTickCount;
private int firstUnansweredWriteTickCount; private int firstUnansweredWriteTickCount;
private Stream netStream, outStream; IDuplexPipe _ioPipe;
private SocketToken socketToken; private SocketToken socketToken;
...@@ -113,19 +113,19 @@ private enum ReadMode : byte ...@@ -113,19 +113,19 @@ private enum ReadMode : byte
public void Dispose() public void Dispose()
{ {
if (outStream != null)
var ioPipe = _ioPipe;
_ioPipe = null;
if(ioPipe != null)
{ {
Multiplexer.Trace("Disconnecting...", physicalName); Multiplexer.Trace("Disconnecting...", physicalName);
try { outStream.Close(); } catch { } try { ioPipe.Input?.CancelPendingRead(); } catch { }
try { outStream.Dispose(); } catch { } try { ioPipe.Input?.Complete(); } catch { }
outStream = null; try { ioPipe.Output?.CancelPendingFlush(); } catch { }
} try { ioPipe.Output?.Complete(); } catch { }
if (netStream != null) ioPipe.Output?.Complete();
{
try { netStream.Close(); } catch { }
try { netStream.Dispose(); } catch { }
netStream = null;
} }
if (socketToken.HasValue) if (socketToken.HasValue)
{ {
Multiplexer.SocketManager?.Shutdown(socketToken); Multiplexer.SocketManager?.Shutdown(socketToken);
...@@ -135,15 +135,21 @@ public void Dispose() ...@@ -135,15 +135,21 @@ public void Dispose()
} }
OnCloseEcho(); OnCloseEcho();
} }
private async Task AwaitedFlush(ValueTask<FlushResult> flush)
public void Flush()
{ {
var tmp = outStream; await flush;
Interlocked.Exchange(ref lastWriteTickCount, Environment.TickCount);
}
public Task FlushAsync()
{
var tmp = _ioPipe?.Output;
if (tmp != null) if (tmp != null)
{ {
tmp.Flush(); var flush = tmp.FlushAsync();
if (!flush.IsCompletedSuccessfully) return AwaitedFlush(flush);
Interlocked.Exchange(ref lastWriteTickCount, Environment.TickCount); Interlocked.Exchange(ref lastWriteTickCount, Environment.TickCount);
} }
return Task.CompletedTask;
} }
public void RecordConnectionFailed(ConnectionFailureType failureType, Exception innerException = null, [CallerMemberName] string origin = null) public void RecordConnectionFailed(ConnectionFailureType failureType, Exception innerException = null, [CallerMemberName] string origin = null)
...@@ -197,7 +203,7 @@ void add(string lk, string sk, string v) ...@@ -197,7 +203,7 @@ void add(string lk, string sk, string v)
} }
add("Origin", "origin", origin); add("Origin", "origin", origin);
add("Input-Buffer", "input-buffer", ioBufferBytes.ToString()); // add("Input-Buffer", "input-buffer", _ioPipe.Input);
add("Outstanding-Responses", "outstanding", GetSentAwaitingResponseCount().ToString()); add("Outstanding-Responses", "outstanding", GetSentAwaitingResponseCount().ToString());
add("Last-Read", "last-read", (unchecked(now - lastRead) / 1000) + "s ago"); add("Last-Read", "last-read", (unchecked(now - lastRead) / 1000) + "s ago");
add("Last-Write", "last-write", (unchecked(now - lastWrite) / 1000) + "s ago"); add("Last-Write", "last-write", (unchecked(now - lastWrite) / 1000) + "s ago");
...@@ -405,28 +411,28 @@ internal void Write(RedisKey key) ...@@ -405,28 +411,28 @@ internal void Write(RedisKey key)
var val = key.KeyValue; var val = key.KeyValue;
if (val is string) if (val is string)
{ {
WriteUnified(outStream, key.KeyPrefix, (string)val); WriteUnified(_ioPipe.Output, key.KeyPrefix, (string)val);
} }
else else
{ {
WriteUnified(outStream, key.KeyPrefix, (byte[])val); WriteUnified(_ioPipe.Output, key.KeyPrefix, (byte[])val);
} }
} }
internal void Write(RedisChannel channel) internal void Write(RedisChannel channel)
{ {
WriteUnified(outStream, ChannelPrefix, channel.Value); WriteUnified(_ioPipe.Output, ChannelPrefix, channel.Value);
} }
internal void Write(RedisValue value) internal void Write(RedisValue value)
{ {
if (value.IsInteger) if (value.IsInteger)
{ {
WriteUnified(outStream, (long)value); WriteUnified(_ioPipe.Output, (long)value);
} }
else else
{ {
WriteUnified(outStream, (byte[])value); WriteUnified(_ioPipe.Output, (byte[])value);
} }
} }
...@@ -437,15 +443,10 @@ internal void WriteHeader(RedisCommand command, int arguments) ...@@ -437,15 +443,10 @@ internal void WriteHeader(RedisCommand command, int arguments)
{ {
throw ExceptionFactory.CommandDisabled(Multiplexer.IncludeDetailInExceptions, command, null, Bridge.ServerEndPoint); throw ExceptionFactory.CommandDisabled(Multiplexer.IncludeDetailInExceptions, command, null, Bridge.ServerEndPoint);
} }
outStream.WriteByte((byte)'*'); WriteHeader(commandBytes, arguments);
// remember the time of the first write that still not followed by read
Interlocked.CompareExchange(ref firstUnansweredWriteTickCount, Environment.TickCount, 0);
WriteRaw(outStream, arguments + 1);
WriteUnified(outStream, commandBytes);
} }
internal const int REDIS_MAX_ARGS = 1024 * 1024; // there is a <= 1024*1024 max constraint inside redis itself: https://github.com/antirez/redis/blob/6c60526db91e23fb2d666fc52facc9a11780a2a3/src/networking.c#L1024 internal const int REDIS_MAX_ARGS = 1024 * 1024; // there is a <= 1024*1024 max constraint inside redis itself: https://github.com/antirez/redis/blob/6c60526db91e23fb2d666fc52facc9a11780a2a3/src/networking.c#L1024
internal void WriteHeader(string command, int arguments) internal void WriteHeader(string command, int arguments)
...@@ -455,39 +456,66 @@ internal void WriteHeader(string command, int arguments) ...@@ -455,39 +456,66 @@ internal void WriteHeader(string command, int arguments)
throw ExceptionFactory.TooManyArgs(Multiplexer.IncludeDetailInExceptions, command, null, Bridge.ServerEndPoint, arguments + 1); throw ExceptionFactory.TooManyArgs(Multiplexer.IncludeDetailInExceptions, command, null, Bridge.ServerEndPoint, arguments + 1);
} }
var commandBytes = Multiplexer.CommandMap.GetBytes(command); var commandBytes = Multiplexer.CommandMap.GetBytes(command);
if (commandBytes == null) WriteHeader(commandBytes, arguments);
{ }
throw ExceptionFactory.CommandDisabled(Multiplexer.IncludeDetailInExceptions, command, null, Bridge.ServerEndPoint); private void WriteHeader(byte[] commandBytes, int arguments)
} {
outStream.WriteByte((byte)'*');
// remember the time of the first write that still not followed by read // remember the time of the first write that still not followed by read
Interlocked.CompareExchange(ref firstUnansweredWriteTickCount, Environment.TickCount, 0); Interlocked.CompareExchange(ref firstUnansweredWriteTickCount, Environment.TickCount, 0);
WriteRaw(outStream, arguments + 1); // *{argCount}\r\n = 3 + MaxInt32TextLen
WriteUnified(outStream, commandBytes); // ${cmd-len}\r\n = 3 + MaxInt32TextLen
// {cmd}\r\n = 2 + commandBytes.Length
var span = _ioPipe.Output.GetSpan(commandBytes.Length + 8 + MaxInt32TextLen + MaxInt32TextLen);
span[0] = (byte)'*';
int offset = WriteRaw(span, arguments + 1, offset: 1);
offset = WriteUnified(span, commandBytes, offset: offset);
_ioPipe.Output.Advance(offset);
} }
const int MaxInt32TextLen = 11, // -2,147,483,648 (not including the commas)
MaxInt64TextLen = 20; // -9,223,372,036,854,775,808 (not including the commas)
private static void WriteRaw(Stream stream, long value, bool withLengthPrefix = false) [MethodImpl(MethodImplOptions.AggressiveInlining)]
static int WriteCrlf(Span<byte> span, int offset)
{
span[offset++] = (byte)'\r';
span[offset++] = (byte)'\n';
return offset;
}
[MethodImpl(MethodImplOptions.AggressiveInlining)]
static void WriteCrlf(PipeWriter writer)
{
var span = writer.GetSpan(2);
span[0] = (byte)'\r';
span[1] = (byte)'\n';
writer.Advance(2);
}
private static int WriteRaw(Span<byte> span, long value, bool withLengthPrefix = false, int offset = 0)
{ {
if (value >= 0 && value <= 9) if (value >= 0 && value <= 9)
{ {
if (withLengthPrefix) if (withLengthPrefix)
{ {
stream.WriteByte((byte)'1'); span[offset++] = (byte)'1';
stream.Write(Crlf, 0, 2); offset = WriteCrlf(span, offset);
} }
stream.WriteByte((byte)((int)'0' + (int)value)); span[offset++] = (byte)((int)'0' + (int)value);
} }
else if (value >= 10 && value < 100) else if (value >= 10 && value < 100)
{ {
if (withLengthPrefix) if (withLengthPrefix)
{ {
stream.WriteByte((byte)'2'); span[offset++] = (byte)'2';
stream.Write(Crlf, 0, 2); offset = WriteCrlf(span, offset);
} }
stream.WriteByte((byte)((int)'0' + ((int)value / 10))); span[offset++] = (byte)((int)'0' + ((int)value / 10));
stream.WriteByte((byte)((int)'0' + ((int)value % 10))); span[offset++] = (byte)((int)'0' + ((int)value % 10));
} }
else if (value >= 100 && value < 1000) else if (value >= 100 && value < 1000)
{ {
...@@ -497,79 +525,143 @@ private static void WriteRaw(Stream stream, long value, bool withLengthPrefix = ...@@ -497,79 +525,143 @@ private static void WriteRaw(Stream stream, long value, bool withLengthPrefix =
int tens = v % 10, hundreds = v / 10; int tens = v % 10, hundreds = v / 10;
if (withLengthPrefix) if (withLengthPrefix)
{ {
stream.WriteByte((byte)'3'); span[offset++] = (byte)'3';
stream.Write(Crlf, 0, 2); offset = WriteCrlf(span, offset);
} }
stream.WriteByte((byte)((int)'0' + hundreds)); span[offset++] = (byte)((int)'0' + hundreds);
stream.WriteByte((byte)((int)'0' + tens)); span[offset++] = (byte)((int)'0' + tens);
stream.WriteByte((byte)((int)'0' + units)); span[offset++] = (byte)((int)'0' + units);
} }
else if (value < 0 && value >= -9) else if (value < 0 && value >= -9)
{ {
if (withLengthPrefix) if (withLengthPrefix)
{ {
stream.WriteByte((byte)'2'); span[offset++] = (byte)'2';
stream.Write(Crlf, 0, 2); offset = WriteCrlf(span, offset);
} }
stream.WriteByte((byte)'-'); span[offset++] = (byte)'-';
stream.WriteByte((byte)((int)'0' - (int)value)); span[offset++] = (byte)((int)'0' - (int)value);
} }
else if (value <= -10 && value > -100) else if (value <= -10 && value > -100)
{ {
if (withLengthPrefix) if (withLengthPrefix)
{ {
stream.WriteByte((byte)'3'); span[offset++] = (byte)'3';
stream.Write(Crlf, 0, 2); offset = WriteCrlf(span, offset);
} }
value = -value; value = -value;
stream.WriteByte((byte)'-'); span[offset++] = (byte)'-';
stream.WriteByte((byte)((int)'0' + ((int)value / 10))); span[offset++] = (byte)((int)'0' + ((int)value / 10));
stream.WriteByte((byte)((int)'0' + ((int)value % 10))); span[offset++] = (byte)((int)'0' + ((int)value % 10));
} }
else else
{ {
var bytes = Encoding.ASCII.GetBytes(Format.ToString(value)); unsafe
if (withLengthPrefix)
{ {
WriteRaw(stream, bytes.Length, false); byte* bytes = stackalloc byte[MaxInt32TextLen];
var s = Format.ToString(value); // need an alloc-free version of this...
int len;
fixed (char* c = s)
{
len = Encoding.ASCII.GetBytes(c, s.Length, bytes, MaxInt32TextLen);
}
if (withLengthPrefix)
{
offset = WriteRaw(span, len, false, offset);
}
new ReadOnlySpan<byte>(bytes, len).CopyTo(span.Slice(offset));
offset += len;
} }
stream.Write(bytes, 0, bytes.Length);
} }
stream.Write(Crlf, 0, 2); return WriteCrlf(span, offset);
} }
private static void WriteUnified(Stream stream, byte[] value) static readonly byte[] NullBulkString = Encoding.ASCII.GetBytes("$-1\r\n"), EmptyBulkString = Encoding.ASCII.GetBytes("$0\r\n\r\n");
private static void WriteUnified(PipeWriter writer, byte[] value)
{ {
stream.WriteByte((byte)'$'); const int MaxQuickSpanSize = 512;
// ${len}\r\n = 3 + MaxInt32TextLen
// {value}\r\n = 2 + value.Length
if (value == null) if (value == null)
{ {
WriteRaw(stream, -1); // note that not many things like this... // special case:
writer.Write(NullBulkString);
}
else if (value.Length == 0)
{
// special case:
writer.Write(EmptyBulkString);
}
else if (value.Length <= MaxQuickSpanSize)
{
var span = writer.GetSpan(5 + MaxInt32TextLen + value.Length);
int bytes = WriteUnified(span, value);
writer.Advance(bytes);
} }
else else
{ {
WriteRaw(stream, value.Length); // too big to guarantee can do in a single span
stream.Write(value, 0, value.Length); var span = writer.GetSpan(3 + MaxInt32TextLen);
stream.Write(Crlf, 0, 2); span[0] = (byte)'$';
int bytes = WriteRaw(span, value.LongLength, offset: 1);
writer.Advance(bytes);
writer.Write(value);
WriteCrlf(writer);
} }
} }
private static int WriteUnified(Span<byte> span, byte[] value, int offset = 0)
internal void WriteAsHex(byte[] value)
{ {
var stream = outStream; span[offset++] = (byte)'$';
stream.WriteByte((byte)'$');
if (value == null) if (value == null)
{ {
WriteRaw(stream, -1); offset = WriteRaw(span, -1, offset: offset); // note that not many things like this...
} }
else else
{ {
WriteRaw(stream, value.Length * 2); offset = WriteRaw(span, value.Length, offset: offset);
for (int i = 0; i < value.Length; i++) new ReadOnlySpan<byte>(value).CopyTo(span.Slice(offset));
offset = WriteCrlf(span, offset);
}
return offset;
}
internal void WriteSha1AsHex(byte[] value)
{
var writer = _ioPipe.Output;
if (value == null)
{
writer.Write(NullBulkString);
}
else if(value.Length == ResultProcessor.ScriptLoadProcessor.Sha1HashLength)
{
// $40\r\n = 5
// {40 bytes}\r\n = 42
var span = writer.GetSpan(47);
span[0] = (byte)'$';
span[1] = (byte)'4';
span[2] = (byte)'0';
span[3] = (byte)'\r';
span[4] = (byte)'\n';
int offset = 5;
for(int i = 0; i < value.Length; i++)
{ {
stream.WriteByte(ToHexNibble(value[i] >> 4)); var b = value[i];
stream.WriteByte(ToHexNibble(value[i] & 15)); span[offset++] = ToHexNibble(value[i] >> 4);
span[offset++] = ToHexNibble(value[i] & 15);
} }
stream.Write(Crlf, 0, 2); span[offset++] = (byte)'\r';
span[offset++] = (byte)'\n';
writer.Advance(offset);
}
else
{
throw new InvalidOperationException("Invalid SHA1 length: " + value.Length);
} }
} }
...@@ -578,90 +670,119 @@ internal static byte ToHexNibble(int value) ...@@ -578,90 +670,119 @@ internal static byte ToHexNibble(int value)
return value < 10 ? (byte)('0' + value) : (byte)('a' - 10 + value); return value < 10 ? (byte)('0' + value) : (byte)('a' - 10 + value);
} }
private void WriteUnified(Stream stream, byte[] prefix, string value) private void WriteUnified(PipeWriter writer, byte[] prefix, string value)
{ {
stream.WriteByte((byte)'$');
if (value == null) if (value == null)
{ {
WriteRaw(stream, -1); // note that not many things like this... // special case
writer.Write(NullBulkString);
} }
else else
{ {
int encodedLength = Encoding.UTF8.GetByteCount(value); // ${total-len}\r\n 3 + MaxInt32TextLen
if (prefix == null) // {prefix}{value}\r\n
int encodedLength = Encoding.UTF8.GetByteCount(value),
prefixLength = prefix == null ? 0 : prefix.Length,
totalLength = prefixLength + encodedLength;
if (totalLength == 0)
{ {
WriteRaw(stream, encodedLength); // special-case
WriteRaw(stream, value, encodedLength); writer.Write(EmptyBulkString);
stream.Write(Crlf, 0, 2);
} }
else else
{ {
WriteRaw(stream, prefix.Length + encodedLength); var span = writer.GetSpan(3 + MaxInt32TextLen);
stream.Write(prefix, 0, prefix.Length); span[0] = (byte)'$';
WriteRaw(stream, value, encodedLength); int bytes = WriteRaw(span, totalLength, offset: 1);
stream.Write(Crlf, 0, 2); writer.Advance(bytes);
if (prefixLength != 0) writer.Write(prefix);
if (encodedLength != 0) WriteRaw(writer, value, encodedLength);
WriteCrlf(writer);
} }
} }
} }
private unsafe void WriteRaw(Stream stream, string value, int encodedLength) private unsafe void WriteRaw(PipeWriter writer, string value, int encodedLength)
{ {
if (encodedLength <= ScratchSize) const int MaxQuickEncodeSize = 512;
{
int bytes = Encoding.UTF8.GetBytes(value, 0, value.Length, outScratch, 0); fixed (char* cPtr = value)
stream.Write(outScratch, 0, bytes);
}
else
{ {
fixed (char* c = value) int totalBytes;
fixed (byte* b = outScratch) if (encodedLength <= MaxQuickEncodeSize)
{
// encode directly in one hit
var span = writer.GetSpan(encodedLength);
fixed (byte* bPtr = &span[0])
{
totalBytes = Encoding.UTF8.GetBytes(cPtr, value.Length, bPtr, encodedLength);
}
writer.Advance(encodedLength);
}
else
{ {
int charsRemaining = value.Length, charOffset = 0, bytesWritten; // use an encoder in a loop
while (charsRemaining > Scratch_CharsPerBlock) outEncoder.Reset();
int charsRemaining = value.Length, charOffset = 0;
totalBytes = 0;
while (charsRemaining != 0)
{ {
bytesWritten = outEncoder.GetBytes(c + charOffset, Scratch_CharsPerBlock, b, ScratchSize, false); // note: at most 4 bytes per UTF8 character, despite what UTF8.GetMaxByteCount says
stream.Write(outScratch, 0, bytesWritten); var span = writer.GetSpan(4); // get *some* memory - at least enough for 1 character (but hopefully lots more)
charOffset += Scratch_CharsPerBlock; int bytesWritten, charsToWrite = span.Length >> 2; // assume worst case, because the API sucks
charsRemaining -= Scratch_CharsPerBlock; fixed (byte* bPtr = &span[0])
{
bytesWritten = outEncoder.GetBytes(cPtr + charOffset, charsToWrite, bPtr, span.Length, false);
}
writer.Advance(bytesWritten);
totalBytes += bytesWritten;
charOffset += charsToWrite;
charsRemaining -= charsRemaining;
} }
bytesWritten = outEncoder.GetBytes(c + charOffset, charsRemaining, b, ScratchSize, true);
if (bytesWritten != 0) stream.Write(outScratch, 0, bytesWritten);
} }
Debug.Assert(totalBytes == encodedLength);
} }
} }
private const int ScratchSize = 512;
private static readonly int Scratch_CharsPerBlock = ScratchSize / Encoding.UTF8.GetMaxByteCount(1);
private readonly byte[] outScratch = new byte[ScratchSize];
private readonly Encoder outEncoder = Encoding.UTF8.GetEncoder(); private readonly Encoder outEncoder = Encoding.UTF8.GetEncoder();
private static void WriteUnified(Stream stream, byte[] prefix, byte[] value) private static void WriteUnified(PipeWriter writer, byte[] prefix, byte[] value)
{ {
stream.WriteByte((byte)'$'); // ${total-len}\r\n
if (value == null) // {prefix}{value}\r\n
{ if (prefix == null || prefix.Length == 0 || value == null)
WriteRaw(stream, -1); // note that not many things like this... { // if no prefix, just use the non-prefixed version;
} // even if prefixed, a null value writes as null, so can use the non-prefixed version
else if (prefix == null) WriteUnified(writer, value);
{
WriteRaw(stream, value.Length);
stream.Write(value, 0, value.Length);
stream.Write(Crlf, 0, 2);
} }
else else
{ {
WriteRaw(stream, prefix.Length + value.Length); var span = writer.GetSpan(3 + MaxInt32TextLen); // note even with 2 max-len, we're still in same text range
stream.Write(prefix, 0, prefix.Length); span[0] = (byte)'$';
stream.Write(value, 0, value.Length); int bytes = WriteRaw(span, prefix.LongLength + value.LongLength, offset: 1);
stream.Write(Crlf, 0, 2); writer.Advance(bytes);
writer.Write(prefix);
writer.Write(value);
span = writer.GetSpan(2);
WriteCrlf(span, 0);
writer.Advance(2);
} }
} }
private static void WriteUnified(Stream stream, long value) private static void WriteUnified(PipeWriter writer, long value)
{ {
// note from specification: A client sends to the Redis server a RESP Array consisting of just Bulk Strings. // note from specification: A client sends to the Redis server a RESP Array consisting of just Bulk Strings.
// (i.e. we can't just send ":123\r\n", we need to send "$3\r\n123\r\n" // (i.e. we can't just send ":123\r\n", we need to send "$3\r\n123\r\n"
stream.WriteByte((byte)'$');
WriteRaw(stream, value, withLengthPrefix: true); // ${asc-len}\r\n = 3 + MaxInt32TextLen
// {asc}\r\n = MaxInt64TextLen + 2
var span = writer.GetSpan(5 + MaxInt32TextLen + MaxInt64TextLen);
span[0] = (byte)'$';
var bytes = WriteRaw(span, value, withLengthPrefix: true, offset: 1);
writer.Advance(bytes);
} }
private void BeginReading() private void BeginReading()
...@@ -720,7 +841,7 @@ private static LocalCertificateSelectionCallback GetAmbientCertificateCallback() ...@@ -720,7 +841,7 @@ private static LocalCertificateSelectionCallback GetAmbientCertificateCallback()
return null; return null;
} }
SocketMode ISocketCallback.Connected(Stream stream, TextWriter log) async ValueTask<SocketMode> ISocketCallback.ConnectedAsync(IDuplexPipe pipe, TextWriter log)
{ {
try try
{ {
...@@ -735,36 +856,37 @@ SocketMode ISocketCallback.Connected(Stream stream, TextWriter log) ...@@ -735,36 +856,37 @@ SocketMode ISocketCallback.Connected(Stream stream, TextWriter log)
if (config.Ssl) if (config.Ssl)
{ {
Multiplexer.LogLocked(log, "Configuring SSL"); throw new NotImplementedException("TLS");
var host = config.SslHost; //Multiplexer.LogLocked(log, "Configuring SSL");
if (string.IsNullOrWhiteSpace(host)) host = Format.ToStringHostOnly(Bridge.ServerEndPoint.EndPoint); //var host = config.SslHost;
//if (string.IsNullOrWhiteSpace(host)) host = Format.ToStringHostOnly(Bridge.ServerEndPoint.EndPoint);
var ssl = new SslStream(stream, false, config.CertificateValidationCallback,
config.CertificateSelectionCallback ?? GetAmbientCertificateCallback(), //var ssl = new SslStream(stream, false, config.CertificateValidationCallback,
EncryptionPolicy.RequireEncryption); // config.CertificateSelectionCallback ?? GetAmbientCertificateCallback(),
try // EncryptionPolicy.RequireEncryption);
{ //try
ssl.AuthenticateAsClient(host, config.SslProtocols); //{
// ssl.AuthenticateAsClient(host, config.SslProtocols);
Multiplexer.LogLocked(log, $"SSL connection established successfully using protocol: {ssl.SslProtocol}");
} // Multiplexer.LogLocked(log, $"SSL connection established successfully using protocol: {ssl.SslProtocol}");
catch (AuthenticationException authexception) //}
{ //catch (AuthenticationException authexception)
RecordConnectionFailed(ConnectionFailureType.AuthenticationFailure, authexception); //{
Multiplexer.Trace("Encryption failure"); // RecordConnectionFailed(ConnectionFailureType.AuthenticationFailure, authexception);
return SocketMode.Abort; // Multiplexer.Trace("Encryption failure");
} // return SocketMode.Abort;
stream = ssl; //}
socketMode = SocketMode.Async; //stream = ssl;
//socketMode = SocketMode.Async;
} }
OnWrapForLogging(ref stream, physicalName); OnWrapForLogging(ref pipe, physicalName);
int bufferSize = config.WriteBuffer; int bufferSize = config.WriteBuffer;
netStream = stream;
outStream = bufferSize <= 0 ? stream : new BufferedStream(stream, bufferSize); _ioPipe = pipe;
Multiplexer.LogLocked(log, "Connected {0}", Bridge); Multiplexer.LogLocked(log, "Connected {0}", Bridge);
Bridge.OnConnected(this, log); await Bridge.OnConnectedAsync(this, log);
return socketMode; return socketMode;
} }
catch (Exception ex) catch (Exception ex)
...@@ -884,7 +1006,7 @@ void ISocketCallback.OnHeartbeat() ...@@ -884,7 +1006,7 @@ void ISocketCallback.OnHeartbeat()
} }
} }
partial void OnWrapForLogging(ref Stream stream, string name); partial void OnWrapForLogging(ref IDuplexPipe pipe, string name);
private int ProcessBuffer(byte[] underlying, ref int offset, ref int count) private int ProcessBuffer(byte[] underlying, ref int offset, ref int count)
{ {
int messageCount = 0; int messageCount = 0;
......
using System; //using System;
namespace StackExchange.Redis //namespace StackExchange.Redis
{ //{
internal static class PlatformHelper // internal static class PlatformHelper
{ // {
public static bool IsMono { get; } = Type.GetType("Mono.Runtime") != null; // public static bool IsMono { get; } = Type.GetType("Mono.Runtime") != null;
public static bool IsUnix { get; } = (int)Environment.OSVersion.Platform == 4 // public static bool IsUnix { get; } = (int)Environment.OSVersion.Platform == 4
|| (int)Environment.OSVersion.Platform == 6 // || (int)Environment.OSVersion.Platform == 6
|| (int)Environment.OSVersion.Platform == 128; // || (int)Environment.OSVersion.Platform == 128;
public static SocketMode DefaultSocketMode = IsMono && IsUnix ? SocketMode.Async : SocketMode.Poll; // public static SocketMode DefaultSocketMode = IsMono && IsUnix ? SocketMode.Async : SocketMode.Poll;
} // }
} //}
...@@ -2521,6 +2521,7 @@ public ScriptEvalMessage(int db, CommandFlags flags, byte[] hash, RedisKey[] key ...@@ -2521,6 +2521,7 @@ public ScriptEvalMessage(int db, CommandFlags flags, byte[] hash, RedisKey[] key
: this(db, flags, RedisCommand.EVAL, null, hash, keys, values) : this(db, flags, RedisCommand.EVAL, null, hash, keys, values)
{ {
if (hash == null) throw new ArgumentNullException(nameof(hash)); if (hash == null) throw new ArgumentNullException(nameof(hash));
if (hash.Length != ResultProcessor.ScriptLoadProcessor.Sha1HashLength) throw new ArgumentOutOfRangeException(nameof(hash), "Invalid hash length");
} }
private ScriptEvalMessage(int db, CommandFlags flags, RedisCommand command, string script, byte[] hexHash, RedisKey[] keys, RedisValue[] values) private ScriptEvalMessage(int db, CommandFlags flags, RedisCommand command, string script, byte[] hexHash, RedisKey[] keys, RedisValue[] values)
...@@ -2571,7 +2572,7 @@ internal override void WriteImpl(PhysicalConnection physical) ...@@ -2571,7 +2572,7 @@ internal override void WriteImpl(PhysicalConnection physical)
if (hexHash != null) if (hexHash != null)
{ {
physical.WriteHeader(RedisCommand.EVALSHA, 2 + keys.Length + values.Length); physical.WriteHeader(RedisCommand.EVALSHA, 2 + keys.Length + values.Length);
physical.WriteAsHex(hexHash); physical.WriteSha1AsHex(hexHash);
} }
else if (asciiHash != null) else if (asciiHash != null)
{ {
......
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Text; using System.Text;
using System.Threading; using System.Threading;
...@@ -250,7 +250,7 @@ public IEnumerable<Message> GetMessages(PhysicalConnection connection) ...@@ -250,7 +250,7 @@ public IEnumerable<Message> GetMessages(PhysicalConnection connection)
{ {
// need to get those sent ASAP; if they are stuck in the buffers, we die // need to get those sent ASAP; if they are stuck in the buffers, we die
multiplexer.Trace("Flushing and waiting for precondition responses"); multiplexer.Trace("Flushing and waiting for precondition responses");
connection.Flush(); connection.FlushAsync().Wait();
if (Monitor.Wait(lastBox, multiplexer.TimeoutMilliseconds)) if (Monitor.Wait(lastBox, multiplexer.TimeoutMilliseconds))
{ {
if (!AreAllConditionsSatisfied(multiplexer)) if (!AreAllConditionsSatisfied(multiplexer))
...@@ -297,7 +297,7 @@ public IEnumerable<Message> GetMessages(PhysicalConnection connection) ...@@ -297,7 +297,7 @@ public IEnumerable<Message> GetMessages(PhysicalConnection connection)
if (explicitCheckForQueued && lastBox != null) if (explicitCheckForQueued && lastBox != null)
{ {
multiplexer.Trace("Flushing and waiting for precondition+queued responses"); multiplexer.Trace("Flushing and waiting for precondition+queued responses");
connection.Flush(); // make sure they get sent, so we can check for QUEUED (and the pre-conditions if necessary) connection.FlushAsync().Wait(); // make sure they get sent, so we can check for QUEUED (and the pre-conditions if necessary)
if (Monitor.Wait(lastBox, multiplexer.TimeoutMilliseconds)) if (Monitor.Wait(lastBox, multiplexer.TimeoutMilliseconds))
{ {
if (!AreAllConditionsSatisfied(multiplexer)) if (!AreAllConditionsSatisfied(multiplexer))
...@@ -475,4 +475,4 @@ protected override bool SetResultCore(PhysicalConnection connection, Message mes ...@@ -475,4 +475,4 @@ protected override bool SetResultCore(PhysicalConnection connection, Message mes
// return ExecuteTransactionAsync(flags); // return ExecuteTransactionAsync(flags);
// } // }
//} //}
} }
\ No newline at end of file
...@@ -393,11 +393,13 @@ internal static bool IsSHA1(string script) ...@@ -393,11 +393,13 @@ internal static bool IsSHA1(string script)
return script != null && sha1.IsMatch(script); return script != null && sha1.IsMatch(script);
} }
internal const int Sha1HashLength = 20;
internal static byte[] ParseSHA1(byte[] value) internal static byte[] ParseSHA1(byte[] value)
{ {
if (value?.Length == 40) if (value?.Length == Sha1HashLength * 2)
{ {
var tmp = new byte[20]; var tmp = new byte[Sha1HashLength];
int charIndex = 0; int charIndex = 0;
for (int i = 0; i < tmp.Length; i++) for (int i = 0; i < tmp.Length; i++)
{ {
...@@ -412,9 +414,9 @@ internal static byte[] ParseSHA1(byte[] value) ...@@ -412,9 +414,9 @@ internal static byte[] ParseSHA1(byte[] value)
internal static byte[] ParseSHA1(string value) internal static byte[] ParseSHA1(string value)
{ {
if (value?.Length == 40 && sha1.IsMatch(value)) if (value?.Length == (Sha1HashLength * 2) && sha1.IsMatch(value))
{ {
var tmp = new byte[20]; var tmp = new byte[Sha1HashLength];
int charIndex = 0; int charIndex = 0;
for (int i = 0; i < tmp.Length; i++) for (int i = 0; i < tmp.Length; i++)
{ {
...@@ -442,7 +444,7 @@ protected override bool SetResultCore(PhysicalConnection connection, Message mes ...@@ -442,7 +444,7 @@ protected override bool SetResultCore(PhysicalConnection connection, Message mes
{ {
case ResultType.BulkString: case ResultType.BulkString:
var asciiHash = result.GetBlob(); var asciiHash = result.GetBlob();
if (asciiHash == null || asciiHash.Length != 40) return false; if (asciiHash == null || asciiHash.Length != (Sha1HashLength * 2)) return false;
byte[] hash = null; byte[] hash = null;
if (!message.IsInternalCall) if (!message.IsInternalCall)
......
...@@ -450,19 +450,33 @@ internal bool IsSelectable(RedisCommand command) ...@@ -450,19 +450,33 @@ internal bool IsSelectable(RedisCommand command)
return bridge?.IsConnected == true; return bridge?.IsConnected == true;
} }
internal void OnEstablishing(PhysicalConnection connection, TextWriter log) internal Task OnEstablishingAsync(PhysicalConnection connection, TextWriter log)
{ {
try try
{ {
if (connection == null) return; if (connection == null) return Task.CompletedTask;
Handshake(connection, log); var handshake = HandshakeAsync(connection, log);
if (handshake.Status != TaskStatus.RanToCompletion)
return OnEstablishingAsyncAwaited(connection, handshake);
}
catch (Exception ex)
{
connection.RecordConnectionFailed(ConnectionFailureType.InternalFailure, ex);
}
return Task.CompletedTask;
}
async Task OnEstablishingAsyncAwaited(PhysicalConnection connection, Task handshake)
{
try
{
await handshake;
} }
catch (Exception ex) catch (Exception ex)
{ {
connection.RecordConnectionFailed(ConnectionFailureType.InternalFailure, ex); connection.RecordConnectionFailed(ConnectionFailureType.InternalFailure, ex);
} }
} }
internal void OnFullyEstablished(PhysicalConnection connection) internal void OnFullyEstablished(PhysicalConnection connection)
{ {
try try
...@@ -627,13 +641,13 @@ private PhysicalBridge CreateBridge(ConnectionType type, TextWriter log) ...@@ -627,13 +641,13 @@ private PhysicalBridge CreateBridge(ConnectionType type, TextWriter log)
return bridge; return bridge;
} }
private void Handshake(PhysicalConnection connection, TextWriter log) private Task HandshakeAsync(PhysicalConnection connection, TextWriter log)
{ {
Multiplexer.LogLocked(log, "Server handshake"); Multiplexer.LogLocked(log, "Server handshake");
if (connection == null) if (connection == null)
{ {
Multiplexer.Trace("No connection!?"); Multiplexer.Trace("No connection!?");
return; return Task.CompletedTask;
} }
Message msg; Message msg;
string password = Multiplexer.RawConfig.Password; string password = Multiplexer.RawConfig.Password;
...@@ -684,7 +698,7 @@ private void Handshake(PhysicalConnection connection, TextWriter log) ...@@ -684,7 +698,7 @@ private void Handshake(PhysicalConnection connection, TextWriter log)
} }
} }
Multiplexer.LogLocked(log, "Flushing outbound buffer"); Multiplexer.LogLocked(log, "Flushing outbound buffer");
connection.Flush(); return connection.FlushAsync();
} }
private void SetConfig<T>(ref T field, T value, [CallerMemberName] string caller = null) private void SetConfig<T>(ref T field, T value, [CallerMemberName] string caller = null)
......
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.IO; using System.IO;
using System.IO.Pipelines;
using System.Net; using System.Net;
using System.Net.Sockets; using System.Net.Sockets;
using System.Threading; using System.Threading;
using System.Threading.Tasks;
using Pipelines.Sockets.Unofficial;
namespace StackExchange.Redis namespace StackExchange.Redis
{ {
internal enum SocketMode internal enum SocketMode
{ {
Abort, Abort,
[Obsolete("just don't", error: true)]
Poll, Poll,
Async Async
} }
...@@ -22,9 +26,9 @@ internal partial interface ISocketCallback ...@@ -22,9 +26,9 @@ internal partial interface ISocketCallback
/// <summary> /// <summary>
/// Indicates that a socket has connected /// Indicates that a socket has connected
/// </summary> /// </summary>
/// <param name="stream">The network stream for this socket.</param> /// <param name="pipe">The network stream for this socket.</param>
/// <param name="log">A text logger to write to.</param> /// <param name="log">A text logger to write to.</param>
SocketMode Connected(Stream stream, TextWriter log); ValueTask<SocketMode> ConnectedAsync(IDuplexPipe pipe, TextWriter log);
/// <summary> /// <summary>
/// Indicates that the socket has signalled an error condition /// Indicates that the socket has signalled an error condition
...@@ -104,16 +108,6 @@ internal enum ManagerState ...@@ -104,16 +108,6 @@ internal enum ManagerState
ProcessErrorQueue, ProcessErrorQueue,
} }
private static readonly ParameterizedThreadStart writeAllQueues = context =>
{
try { ((SocketManager)context).WriteAllQueues(); } catch { }
};
private static readonly WaitCallback writeOneQueue = context =>
{
try { ((SocketManager)context).WriteOneQueue(); } catch { }
};
private readonly Queue<PhysicalBridge> writeQueue = new Queue<PhysicalBridge>(); private readonly Queue<PhysicalBridge> writeQueue = new Queue<PhysicalBridge>();
private bool isDisposed; private bool isDisposed;
private readonly bool useHighPrioritySocketThreads = true; private readonly bool useHighPrioritySocketThreads = true;
...@@ -140,14 +134,12 @@ public SocketManager(string name, bool useHighPrioritySocketThreads) ...@@ -140,14 +134,12 @@ public SocketManager(string name, bool useHighPrioritySocketThreads)
Name = name; Name = name;
this.useHighPrioritySocketThreads = useHighPrioritySocketThreads; this.useHighPrioritySocketThreads = useHighPrioritySocketThreads;
// we need a dedicated writer, because when under heavy ambient load _writeOneQueueAsync = () => WriteOneQueueAsync();
// (a busy asp.net site, for example), workers are not reliable enough
var dedicatedWriter = new Thread(writeAllQueues, 32 * 1024); // don't need a huge stack; Task.Run(() => WriteAllQueuesAsync());
dedicatedWriter.Priority = useHighPrioritySocketThreads ? ThreadPriority.AboveNormal : ThreadPriority.Normal;
dedicatedWriter.Name = name + ":Write";
dedicatedWriter.IsBackground = true; // should not keep process alive
dedicatedWriter.Start(this); // will self-exit when disposed
} }
private readonly Func<Task> _writeOneQueueAsync;
private enum CallbackOperation private enum CallbackOperation
{ {
...@@ -171,54 +163,16 @@ public void Dispose() ...@@ -171,54 +163,16 @@ public void Dispose()
internal SocketToken BeginConnect(EndPoint endpoint, ISocketCallback callback, ConnectionMultiplexer multiplexer, TextWriter log) internal SocketToken BeginConnect(EndPoint endpoint, ISocketCallback callback, ConnectionMultiplexer multiplexer, TextWriter log)
{ {
void RunWithCompletionType(Func<AsyncCallback, IAsyncResult> beginAsync, AsyncCallback asyncCallback)
{
void proxyCallback(IAsyncResult ar)
{
if (!ar.CompletedSynchronously)
{
asyncCallback(ar);
}
}
var result = beginAsync(proxyCallback);
if (result.CompletedSynchronously)
{
result.AsyncWaitHandle.WaitOne();
asyncCallback(result);
}
}
var addressFamily = endpoint.AddressFamily == AddressFamily.Unspecified ? AddressFamily.InterNetwork : endpoint.AddressFamily; var addressFamily = endpoint.AddressFamily == AddressFamily.Unspecified ? AddressFamily.InterNetwork : endpoint.AddressFamily;
var socket = new Socket(addressFamily, SocketType.Stream, ProtocolType.Tcp); var socket = new Socket(addressFamily, SocketType.Stream, ProtocolType.Tcp);
SetFastLoopbackOption(socket);
socket.NoDelay = true;
try try
{ {
var formattedEndpoint = Format.ToString(endpoint); var formattedEndpoint = Format.ToString(endpoint);
var tuple = Tuple.Create(socket, callback);
multiplexer.LogLocked(log, "BeginConnect: {0}", formattedEndpoint); SocketConnection.ConnectAsync(endpoint, PipeOptions.Default,
// A work-around for a Mono bug in BeginConnect(EndPoint endpoint, AsyncCallback callback, object state) conn => EndConnectAsync(conn, multiplexer, log, callback), socket);
if (endpoint is DnsEndPoint dnsEndpoint)
{
RunWithCompletionType(
cb => socket.BeginConnect(dnsEndpoint.Host, dnsEndpoint.Port, cb, tuple),
ar => {
multiplexer.LogLocked(log, "EndConnect: {0}", formattedEndpoint);
EndConnectImpl(ar, multiplexer, log, tuple);
multiplexer.LogLocked(log, "Connect complete: {0}", formattedEndpoint);
});
}
else
{
RunWithCompletionType(
cb => socket.BeginConnect(endpoint, cb, tuple),
ar => {
multiplexer.LogLocked(log, "EndConnect: {0}", formattedEndpoint);
EndConnectImpl(ar, multiplexer, log, tuple);
multiplexer.LogLocked(log, "Connect complete: {0}", formattedEndpoint);
});
}
} }
catch (NotImplementedException ex) catch (NotImplementedException ex)
{ {
...@@ -232,25 +186,7 @@ void proxyCallback(IAsyncResult ar) ...@@ -232,25 +186,7 @@ void proxyCallback(IAsyncResult ar)
return token; return token;
} }
internal void SetFastLoopbackOption(Socket socket)
{
// SIO_LOOPBACK_FAST_PATH (https://msdn.microsoft.com/en-us/library/windows/desktop/jj841212%28v=vs.85%29.aspx)
// Speeds up localhost operations significantly. OK to apply to a socket that will not be hooked up to localhost,
// or will be subject to WFP filtering.
const int SIO_LOOPBACK_FAST_PATH = -1744830448;
// windows only
if (Environment.OSVersion.Platform == PlatformID.Win32NT)
{
// Win8/Server2012+ only
var osVersion = Environment.OSVersion.Version;
if (osVersion.Major > 6 || (osVersion.Major == 6 && osVersion.Minor >= 2))
{
byte[] optionInValue = BitConverter.GetBytes(1);
socket.IOControl(SIO_LOOPBACK_FAST_PATH, optionInValue, null);
}
}
}
internal void RequestWrite(PhysicalBridge bridge, bool forced) internal void RequestWrite(PhysicalBridge bridge, bool forced)
{ {
...@@ -265,35 +201,29 @@ internal void RequestWrite(PhysicalBridge bridge, bool forced) ...@@ -265,35 +201,29 @@ internal void RequestWrite(PhysicalBridge bridge, bool forced)
} }
else if (writeQueue.Count >= 2) else if (writeQueue.Count >= 2)
{ // struggling are we? let's have some help dealing with the backlog { // struggling are we? let's have some help dealing with the backlog
ThreadPool.QueueUserWorkItem(writeOneQueue, this); Task.Run(_writeOneQueueAsync);
} }
} }
} }
} }
internal void Shutdown(SocketToken token) internal void Shutdown(SocketToken token)
{ {
Shutdown(token.Socket); Shutdown(token.Socket);
} }
private void EndConnectImpl(IAsyncResult ar, ConnectionMultiplexer multiplexer, TextWriter log, Tuple<Socket, ISocketCallback> tuple) private async Task EndConnectAsync(SocketConnection connection, ConnectionMultiplexer multiplexer, TextWriter log, ISocketCallback callback)
{ {
try try
{ {
bool ignoreConnect = false; bool ignoreConnect = false;
ShouldIgnoreConnect(tuple.Item2, ref ignoreConnect); var socket = connection?.Socket;
ShouldIgnoreConnect(callback, ref ignoreConnect);
if (ignoreConnect) return; if (ignoreConnect) return;
var socket = tuple.Item1;
var callback = tuple.Item2; var socketMode = callback == null ? SocketMode.Abort : await callback.ConnectedAsync(connection, log);
socket.EndConnect(ar);
var netStream = new NetworkStream(socket, false);
var socketMode = callback?.Connected(netStream, log) ?? SocketMode.Abort;
switch (socketMode) switch (socketMode)
{ {
case SocketMode.Poll:
multiplexer.LogLocked(log, "Starting poll");
OnAddRead(socket, callback);
break;
case SocketMode.Async: case SocketMode.Async:
multiplexer.LogLocked(log, "Starting read"); multiplexer.LogLocked(log, "Starting read");
try try
...@@ -313,10 +243,9 @@ private void EndConnectImpl(IAsyncResult ar, ConnectionMultiplexer multiplexer, ...@@ -313,10 +243,9 @@ private void EndConnectImpl(IAsyncResult ar, ConnectionMultiplexer multiplexer,
catch (ObjectDisposedException) catch (ObjectDisposedException)
{ {
multiplexer.LogLocked(log, "(socket shutdown)"); multiplexer.LogLocked(log, "(socket shutdown)");
if (tuple != null) if (callback != null)
{ {
try try { callback.Error(); }
{ tuple.Item2.Error(); }
catch (Exception inner) catch (Exception inner)
{ {
ConnectionMultiplexer.TraceWithoutContext(inner.Message); ConnectionMultiplexer.TraceWithoutContext(inner.Message);
...@@ -326,10 +255,9 @@ private void EndConnectImpl(IAsyncResult ar, ConnectionMultiplexer multiplexer, ...@@ -326,10 +255,9 @@ private void EndConnectImpl(IAsyncResult ar, ConnectionMultiplexer multiplexer,
catch(Exception outer) catch(Exception outer)
{ {
ConnectionMultiplexer.TraceWithoutContext(outer.Message); ConnectionMultiplexer.TraceWithoutContext(outer.Message);
if (tuple != null) if (callback != null)
{ {
try try { callback.Error(); }
{ tuple.Item2.Error(); }
catch (Exception inner) catch (Exception inner)
{ {
ConnectionMultiplexer.TraceWithoutContext(inner.Message); ConnectionMultiplexer.TraceWithoutContext(inner.Message);
...@@ -355,7 +283,7 @@ private void Shutdown(Socket socket) ...@@ -355,7 +283,7 @@ private void Shutdown(Socket socket)
} }
} }
private void WriteAllQueues() private async Task WriteAllQueuesAsync()
{ {
while (true) while (true)
{ {
...@@ -372,7 +300,7 @@ private void WriteAllQueues() ...@@ -372,7 +300,7 @@ private void WriteAllQueues()
bridge = writeQueue.Dequeue(); bridge = writeQueue.Dequeue();
} }
switch (bridge.WriteQueue(200)) switch (await bridge.WriteQueueAsync(200))
{ {
case WriteResult.MoreWork: case WriteResult.MoreWork:
case WriteResult.QueueEmptyAfterWrite: case WriteResult.QueueEmptyAfterWrite:
...@@ -400,18 +328,22 @@ private void WriteAllQueues() ...@@ -400,18 +328,22 @@ private void WriteAllQueues()
} }
} }
private void WriteOneQueue() private Task WriteOneQueueAsync()
{ {
PhysicalBridge bridge; PhysicalBridge bridge;
lock (writeQueue) lock (writeQueue)
{ {
bridge = writeQueue.Count == 0 ? null : writeQueue.Dequeue(); bridge = writeQueue.Count == 0 ? null : writeQueue.Dequeue();
} }
if (bridge == null) return; if (bridge == null) return Task.CompletedTask;
return WriteOneQueueAsyncImpl(bridge);
}
private async Task WriteOneQueueAsyncImpl(PhysicalBridge bridge)
{
bool keepGoing; bool keepGoing;
do do
{ {
switch (bridge.WriteQueue(-1)) switch (await bridge.WriteQueueAsync(-1))
{ {
case WriteResult.MoreWork: case WriteResult.MoreWork:
case WriteResult.QueueEmptyAfterWrite: case WriteResult.QueueEmptyAfterWrite:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment