Commit 6af1811b authored by Marc Gravell's avatar Marc Gravell

all of the outbound changes

parent 5c31bb25
......@@ -7,7 +7,7 @@
<AssemblyOriginatorKeyFile>../StackExchange.Redis.snk</AssemblyOriginatorKeyFile>
<PackageId>$(AssemblyName)</PackageId>
<Authors>Stack Exchange, Inc.; marc.gravell</Authors>
<AllowUnsafeBlocks>true</AllowUnsafeBlocks>
<AllowUnsafeBlocks>true</AllowUnsafeBlocks>
<PackageReleaseNotes>https://stackexchange.github.io/StackExchange.Redis/ReleaseNotes</PackageReleaseNotes>
<PackageProjectUrl>https://github.com/StackExchange/StackExchange.Redis/</PackageProjectUrl>
......@@ -21,7 +21,7 @@
<DefaultLanguage>en-US</DefaultLanguage>
<IncludeSymbols>false</IncludeSymbols>
<LibraryTargetFrameworks>net45;net46;netstandard2.0</LibraryTargetFrameworks>
<LibraryTargetFrameworks>net46;netstandard2.0</LibraryTargetFrameworks>
<CoreFxVersion>4.5.0</CoreFxVersion>
<xUnitVersion>2.4.0-beta.2.build3981</xUnitVersion>
</PropertyGroup>
......
......@@ -18,7 +18,7 @@
</ItemGroup>
<PropertyGroup Condition=" '$(TargetFramework)' == 'net45' or '$(TargetFramework)' == 'net46'">
<DefineConstants>$(DefineConstants);FEATURE_SOCKET_MODE_POLL;FEATURE_PERFCOUNTER;</DefineConstants>
<DefineConstants>$(DefineConstants);FEATURE_PERFCOUNTER;</DefineConstants>
</PropertyGroup>
<ItemGroup>
......@@ -27,9 +27,9 @@
<PackageReference Include="System.Reflection.Emit.ILGeneration" Version="4.3.0" />
<PackageReference Include="System.Reflection.Emit.Lightweight" Version="4.3.0" />
<!--<PackageReference Include="System.Memory" Version="$(CoreFxVersion)" />
<PackageReference Include="System.Memory" Version="$(CoreFxVersion)" />
<PackageReference Include="System.Buffers" Version="$(CoreFxVersion)" />
<PackageReference Include="System.IO.Pipelines" Version="$(CoreFxVersion)" />
<PackageReference Include="Pipelines.Sockets.Unofficial" Version="0.2.0-alpha-001" />-->
<PackageReference Include="Pipelines.Sockets.Unofficial" Version="0.2.0-alpha-004" />
</ItemGroup>
</Project>
\ No newline at end of file
......@@ -5,6 +5,7 @@
using System.Runtime.CompilerServices;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
namespace StackExchange.Redis
{
......@@ -267,21 +268,19 @@ internal void KeepAlive()
}
}
internal void OnConnected(PhysicalConnection connection, TextWriter log)
internal async Task OnConnectedAsync(PhysicalConnection connection, TextWriter log)
{
Trace("OnConnected");
if (physical == connection && !isDisposed && ChangeState(State.Connecting, State.ConnectedEstablishing))
{
ServerEndPoint.OnEstablishing(connection, log);
await ServerEndPoint.OnEstablishingAsync(connection, log);
}
else
{
try
{
connection.Dispose();
}
catch
{ }
} catch { }
}
}
......@@ -552,7 +551,10 @@ internal bool WriteMessageDirect(PhysicalConnection tmp, Message next)
}
}
internal WriteResult WriteQueue(int maxWork)
[MethodImpl(MethodImplOptions.AggressiveInlining)]
static ValueTask<T> AsResult<T>(T value) => new ValueTask<T>(value);
internal async ValueTask<WriteResult> WriteQueueAsync(int maxWork)
{
bool weAreWriter = false;
PhysicalConnection conn = null;
......@@ -585,7 +587,7 @@ internal WriteResult WriteQueue(int maxWork)
Trace("Nothing to write; exiting");
if(count == 0)
{
conn.Flush(); // only flush on an empty run
await conn.FlushAsync(); // only flush on an empty run
return WriteResult.NothingToDo;
}
return WriteResult.QueueEmptyAfterWrite;
......@@ -604,7 +606,7 @@ internal WriteResult WriteQueue(int maxWork)
{
Trace("Work limit; exiting");
Trace(last != null, "Flushed up to: " + last);
conn.Flush();
await conn.FlushAsync();
break;
}
}
......
using System;
using System.Buffers;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using System.IO.Pipelines;
using System.Linq;
using System.Net;
using System.Net.Security;
......@@ -10,6 +13,7 @@
using System.Security.Cryptography.X509Certificates;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
namespace StackExchange.Redis
{
......@@ -60,14 +64,10 @@ private static readonly Message
private int failureReported;
private byte[] ioBuffer = new byte[512];
private int ioBufferBytes = 0;
private int lastWriteTickCount, lastReadTickCount, lastBeatTickCount;
private int firstUnansweredWriteTickCount;
private Stream netStream, outStream;
IDuplexPipe _ioPipe;
private SocketToken socketToken;
......@@ -113,19 +113,19 @@ private enum ReadMode : byte
public void Dispose()
{
if (outStream != null)
var ioPipe = _ioPipe;
_ioPipe = null;
if(ioPipe != null)
{
Multiplexer.Trace("Disconnecting...", physicalName);
try { outStream.Close(); } catch { }
try { outStream.Dispose(); } catch { }
outStream = null;
}
if (netStream != null)
{
try { netStream.Close(); } catch { }
try { netStream.Dispose(); } catch { }
netStream = null;
try { ioPipe.Input?.CancelPendingRead(); } catch { }
try { ioPipe.Input?.Complete(); } catch { }
try { ioPipe.Output?.CancelPendingFlush(); } catch { }
try { ioPipe.Output?.Complete(); } catch { }
ioPipe.Output?.Complete();
}
if (socketToken.HasValue)
{
Multiplexer.SocketManager?.Shutdown(socketToken);
......@@ -135,15 +135,21 @@ public void Dispose()
}
OnCloseEcho();
}
public void Flush()
private async Task AwaitedFlush(ValueTask<FlushResult> flush)
{
var tmp = outStream;
await flush;
Interlocked.Exchange(ref lastWriteTickCount, Environment.TickCount);
}
public Task FlushAsync()
{
var tmp = _ioPipe?.Output;
if (tmp != null)
{
tmp.Flush();
var flush = tmp.FlushAsync();
if (!flush.IsCompletedSuccessfully) return AwaitedFlush(flush);
Interlocked.Exchange(ref lastWriteTickCount, Environment.TickCount);
}
return Task.CompletedTask;
}
public void RecordConnectionFailed(ConnectionFailureType failureType, Exception innerException = null, [CallerMemberName] string origin = null)
......@@ -197,7 +203,7 @@ void add(string lk, string sk, string v)
}
add("Origin", "origin", origin);
add("Input-Buffer", "input-buffer", ioBufferBytes.ToString());
// add("Input-Buffer", "input-buffer", _ioPipe.Input);
add("Outstanding-Responses", "outstanding", GetSentAwaitingResponseCount().ToString());
add("Last-Read", "last-read", (unchecked(now - lastRead) / 1000) + "s ago");
add("Last-Write", "last-write", (unchecked(now - lastWrite) / 1000) + "s ago");
......@@ -405,28 +411,28 @@ internal void Write(RedisKey key)
var val = key.KeyValue;
if (val is string)
{
WriteUnified(outStream, key.KeyPrefix, (string)val);
WriteUnified(_ioPipe.Output, key.KeyPrefix, (string)val);
}
else
{
WriteUnified(outStream, key.KeyPrefix, (byte[])val);
WriteUnified(_ioPipe.Output, key.KeyPrefix, (byte[])val);
}
}
internal void Write(RedisChannel channel)
{
WriteUnified(outStream, ChannelPrefix, channel.Value);
WriteUnified(_ioPipe.Output, ChannelPrefix, channel.Value);
}
internal void Write(RedisValue value)
{
if (value.IsInteger)
{
WriteUnified(outStream, (long)value);
WriteUnified(_ioPipe.Output, (long)value);
}
else
{
WriteUnified(outStream, (byte[])value);
WriteUnified(_ioPipe.Output, (byte[])value);
}
}
......@@ -437,15 +443,10 @@ internal void WriteHeader(RedisCommand command, int arguments)
{
throw ExceptionFactory.CommandDisabled(Multiplexer.IncludeDetailInExceptions, command, null, Bridge.ServerEndPoint);
}
outStream.WriteByte((byte)'*');
// remember the time of the first write that still not followed by read
Interlocked.CompareExchange(ref firstUnansweredWriteTickCount, Environment.TickCount, 0);
WriteRaw(outStream, arguments + 1);
WriteUnified(outStream, commandBytes);
WriteHeader(commandBytes, arguments);
}
internal const int REDIS_MAX_ARGS = 1024 * 1024; // there is a <= 1024*1024 max constraint inside redis itself: https://github.com/antirez/redis/blob/6c60526db91e23fb2d666fc52facc9a11780a2a3/src/networking.c#L1024
internal void WriteHeader(string command, int arguments)
......@@ -455,39 +456,66 @@ internal void WriteHeader(string command, int arguments)
throw ExceptionFactory.TooManyArgs(Multiplexer.IncludeDetailInExceptions, command, null, Bridge.ServerEndPoint, arguments + 1);
}
var commandBytes = Multiplexer.CommandMap.GetBytes(command);
if (commandBytes == null)
{
throw ExceptionFactory.CommandDisabled(Multiplexer.IncludeDetailInExceptions, command, null, Bridge.ServerEndPoint);
}
outStream.WriteByte((byte)'*');
WriteHeader(commandBytes, arguments);
}
private void WriteHeader(byte[] commandBytes, int arguments)
{
// remember the time of the first write that still not followed by read
Interlocked.CompareExchange(ref firstUnansweredWriteTickCount, Environment.TickCount, 0);
WriteRaw(outStream, arguments + 1);
WriteUnified(outStream, commandBytes);
// *{argCount}\r\n = 3 + MaxInt32TextLen
// ${cmd-len}\r\n = 3 + MaxInt32TextLen
// {cmd}\r\n = 2 + commandBytes.Length
var span = _ioPipe.Output.GetSpan(commandBytes.Length + 8 + MaxInt32TextLen + MaxInt32TextLen);
span[0] = (byte)'*';
int offset = WriteRaw(span, arguments + 1, offset: 1);
offset = WriteUnified(span, commandBytes, offset: offset);
_ioPipe.Output.Advance(offset);
}
const int MaxInt32TextLen = 11, // -2,147,483,648 (not including the commas)
MaxInt64TextLen = 20; // -9,223,372,036,854,775,808 (not including the commas)
private static void WriteRaw(Stream stream, long value, bool withLengthPrefix = false)
[MethodImpl(MethodImplOptions.AggressiveInlining)]
static int WriteCrlf(Span<byte> span, int offset)
{
span[offset++] = (byte)'\r';
span[offset++] = (byte)'\n';
return offset;
}
[MethodImpl(MethodImplOptions.AggressiveInlining)]
static void WriteCrlf(PipeWriter writer)
{
var span = writer.GetSpan(2);
span[0] = (byte)'\r';
span[1] = (byte)'\n';
writer.Advance(2);
}
private static int WriteRaw(Span<byte> span, long value, bool withLengthPrefix = false, int offset = 0)
{
if (value >= 0 && value <= 9)
{
if (withLengthPrefix)
{
stream.WriteByte((byte)'1');
stream.Write(Crlf, 0, 2);
span[offset++] = (byte)'1';
offset = WriteCrlf(span, offset);
}
stream.WriteByte((byte)((int)'0' + (int)value));
span[offset++] = (byte)((int)'0' + (int)value);
}
else if (value >= 10 && value < 100)
{
if (withLengthPrefix)
{
stream.WriteByte((byte)'2');
stream.Write(Crlf, 0, 2);
span[offset++] = (byte)'2';
offset = WriteCrlf(span, offset);
}
stream.WriteByte((byte)((int)'0' + ((int)value / 10)));
stream.WriteByte((byte)((int)'0' + ((int)value % 10)));
span[offset++] = (byte)((int)'0' + ((int)value / 10));
span[offset++] = (byte)((int)'0' + ((int)value % 10));
}
else if (value >= 100 && value < 1000)
{
......@@ -497,79 +525,143 @@ private static void WriteRaw(Stream stream, long value, bool withLengthPrefix =
int tens = v % 10, hundreds = v / 10;
if (withLengthPrefix)
{
stream.WriteByte((byte)'3');
stream.Write(Crlf, 0, 2);
span[offset++] = (byte)'3';
offset = WriteCrlf(span, offset);
}
stream.WriteByte((byte)((int)'0' + hundreds));
stream.WriteByte((byte)((int)'0' + tens));
stream.WriteByte((byte)((int)'0' + units));
span[offset++] = (byte)((int)'0' + hundreds);
span[offset++] = (byte)((int)'0' + tens);
span[offset++] = (byte)((int)'0' + units);
}
else if (value < 0 && value >= -9)
{
if (withLengthPrefix)
{
stream.WriteByte((byte)'2');
stream.Write(Crlf, 0, 2);
span[offset++] = (byte)'2';
offset = WriteCrlf(span, offset);
}
stream.WriteByte((byte)'-');
stream.WriteByte((byte)((int)'0' - (int)value));
span[offset++] = (byte)'-';
span[offset++] = (byte)((int)'0' - (int)value);
}
else if (value <= -10 && value > -100)
{
if (withLengthPrefix)
{
stream.WriteByte((byte)'3');
stream.Write(Crlf, 0, 2);
span[offset++] = (byte)'3';
offset = WriteCrlf(span, offset);
}
value = -value;
stream.WriteByte((byte)'-');
stream.WriteByte((byte)((int)'0' + ((int)value / 10)));
stream.WriteByte((byte)((int)'0' + ((int)value % 10)));
span[offset++] = (byte)'-';
span[offset++] = (byte)((int)'0' + ((int)value / 10));
span[offset++] = (byte)((int)'0' + ((int)value % 10));
}
else
{
var bytes = Encoding.ASCII.GetBytes(Format.ToString(value));
if (withLengthPrefix)
unsafe
{
WriteRaw(stream, bytes.Length, false);
byte* bytes = stackalloc byte[MaxInt32TextLen];
var s = Format.ToString(value); // need an alloc-free version of this...
int len;
fixed (char* c = s)
{
len = Encoding.ASCII.GetBytes(c, s.Length, bytes, MaxInt32TextLen);
}
if (withLengthPrefix)
{
offset = WriteRaw(span, len, false, offset);
}
new ReadOnlySpan<byte>(bytes, len).CopyTo(span.Slice(offset));
offset += len;
}
stream.Write(bytes, 0, bytes.Length);
}
stream.Write(Crlf, 0, 2);
return WriteCrlf(span, offset);
}
private static void WriteUnified(Stream stream, byte[] value)
static readonly byte[] NullBulkString = Encoding.ASCII.GetBytes("$-1\r\n"), EmptyBulkString = Encoding.ASCII.GetBytes("$0\r\n\r\n");
private static void WriteUnified(PipeWriter writer, byte[] value)
{
stream.WriteByte((byte)'$');
const int MaxQuickSpanSize = 512;
// ${len}\r\n = 3 + MaxInt32TextLen
// {value}\r\n = 2 + value.Length
if (value == null)
{
WriteRaw(stream, -1); // note that not many things like this...
// special case:
writer.Write(NullBulkString);
}
else if (value.Length == 0)
{
// special case:
writer.Write(EmptyBulkString);
}
else if (value.Length <= MaxQuickSpanSize)
{
var span = writer.GetSpan(5 + MaxInt32TextLen + value.Length);
int bytes = WriteUnified(span, value);
writer.Advance(bytes);
}
else
{
WriteRaw(stream, value.Length);
stream.Write(value, 0, value.Length);
stream.Write(Crlf, 0, 2);
// too big to guarantee can do in a single span
var span = writer.GetSpan(3 + MaxInt32TextLen);
span[0] = (byte)'$';
int bytes = WriteRaw(span, value.LongLength, offset: 1);
writer.Advance(bytes);
writer.Write(value);
WriteCrlf(writer);
}
}
internal void WriteAsHex(byte[] value)
private static int WriteUnified(Span<byte> span, byte[] value, int offset = 0)
{
var stream = outStream;
stream.WriteByte((byte)'$');
span[offset++] = (byte)'$';
if (value == null)
{
WriteRaw(stream, -1);
offset = WriteRaw(span, -1, offset: offset); // note that not many things like this...
}
else
{
WriteRaw(stream, value.Length * 2);
for (int i = 0; i < value.Length; i++)
offset = WriteRaw(span, value.Length, offset: offset);
new ReadOnlySpan<byte>(value).CopyTo(span.Slice(offset));
offset = WriteCrlf(span, offset);
}
return offset;
}
internal void WriteSha1AsHex(byte[] value)
{
var writer = _ioPipe.Output;
if (value == null)
{
writer.Write(NullBulkString);
}
else if(value.Length == ResultProcessor.ScriptLoadProcessor.Sha1HashLength)
{
// $40\r\n = 5
// {40 bytes}\r\n = 42
var span = writer.GetSpan(47);
span[0] = (byte)'$';
span[1] = (byte)'4';
span[2] = (byte)'0';
span[3] = (byte)'\r';
span[4] = (byte)'\n';
int offset = 5;
for(int i = 0; i < value.Length; i++)
{
stream.WriteByte(ToHexNibble(value[i] >> 4));
stream.WriteByte(ToHexNibble(value[i] & 15));
var b = value[i];
span[offset++] = ToHexNibble(value[i] >> 4);
span[offset++] = ToHexNibble(value[i] & 15);
}
stream.Write(Crlf, 0, 2);
span[offset++] = (byte)'\r';
span[offset++] = (byte)'\n';
writer.Advance(offset);
}
else
{
throw new InvalidOperationException("Invalid SHA1 length: " + value.Length);
}
}
......@@ -578,90 +670,119 @@ internal static byte ToHexNibble(int value)
return value < 10 ? (byte)('0' + value) : (byte)('a' - 10 + value);
}
private void WriteUnified(Stream stream, byte[] prefix, string value)
private void WriteUnified(PipeWriter writer, byte[] prefix, string value)
{
stream.WriteByte((byte)'$');
if (value == null)
{
WriteRaw(stream, -1); // note that not many things like this...
// special case
writer.Write(NullBulkString);
}
else
{
int encodedLength = Encoding.UTF8.GetByteCount(value);
if (prefix == null)
// ${total-len}\r\n 3 + MaxInt32TextLen
// {prefix}{value}\r\n
int encodedLength = Encoding.UTF8.GetByteCount(value),
prefixLength = prefix == null ? 0 : prefix.Length,
totalLength = prefixLength + encodedLength;
if (totalLength == 0)
{
WriteRaw(stream, encodedLength);
WriteRaw(stream, value, encodedLength);
stream.Write(Crlf, 0, 2);
// special-case
writer.Write(EmptyBulkString);
}
else
{
WriteRaw(stream, prefix.Length + encodedLength);
stream.Write(prefix, 0, prefix.Length);
WriteRaw(stream, value, encodedLength);
stream.Write(Crlf, 0, 2);
var span = writer.GetSpan(3 + MaxInt32TextLen);
span[0] = (byte)'$';
int bytes = WriteRaw(span, totalLength, offset: 1);
writer.Advance(bytes);
if (prefixLength != 0) writer.Write(prefix);
if (encodedLength != 0) WriteRaw(writer, value, encodedLength);
WriteCrlf(writer);
}
}
}
private unsafe void WriteRaw(Stream stream, string value, int encodedLength)
private unsafe void WriteRaw(PipeWriter writer, string value, int encodedLength)
{
if (encodedLength <= ScratchSize)
{
int bytes = Encoding.UTF8.GetBytes(value, 0, value.Length, outScratch, 0);
stream.Write(outScratch, 0, bytes);
}
else
const int MaxQuickEncodeSize = 512;
fixed (char* cPtr = value)
{
fixed (char* c = value)
fixed (byte* b = outScratch)
int totalBytes;
if (encodedLength <= MaxQuickEncodeSize)
{
// encode directly in one hit
var span = writer.GetSpan(encodedLength);
fixed (byte* bPtr = &span[0])
{
totalBytes = Encoding.UTF8.GetBytes(cPtr, value.Length, bPtr, encodedLength);
}
writer.Advance(encodedLength);
}
else
{
int charsRemaining = value.Length, charOffset = 0, bytesWritten;
while (charsRemaining > Scratch_CharsPerBlock)
// use an encoder in a loop
outEncoder.Reset();
int charsRemaining = value.Length, charOffset = 0;
totalBytes = 0;
while (charsRemaining != 0)
{
bytesWritten = outEncoder.GetBytes(c + charOffset, Scratch_CharsPerBlock, b, ScratchSize, false);
stream.Write(outScratch, 0, bytesWritten);
charOffset += Scratch_CharsPerBlock;
charsRemaining -= Scratch_CharsPerBlock;
// note: at most 4 bytes per UTF8 character, despite what UTF8.GetMaxByteCount says
var span = writer.GetSpan(4); // get *some* memory - at least enough for 1 character (but hopefully lots more)
int bytesWritten, charsToWrite = span.Length >> 2; // assume worst case, because the API sucks
fixed (byte* bPtr = &span[0])
{
bytesWritten = outEncoder.GetBytes(cPtr + charOffset, charsToWrite, bPtr, span.Length, false);
}
writer.Advance(bytesWritten);
totalBytes += bytesWritten;
charOffset += charsToWrite;
charsRemaining -= charsRemaining;
}
bytesWritten = outEncoder.GetBytes(c + charOffset, charsRemaining, b, ScratchSize, true);
if (bytesWritten != 0) stream.Write(outScratch, 0, bytesWritten);
}
Debug.Assert(totalBytes == encodedLength);
}
}
private const int ScratchSize = 512;
private static readonly int Scratch_CharsPerBlock = ScratchSize / Encoding.UTF8.GetMaxByteCount(1);
private readonly byte[] outScratch = new byte[ScratchSize];
private readonly Encoder outEncoder = Encoding.UTF8.GetEncoder();
private static void WriteUnified(Stream stream, byte[] prefix, byte[] value)
private static void WriteUnified(PipeWriter writer, byte[] prefix, byte[] value)
{
stream.WriteByte((byte)'$');
if (value == null)
{
WriteRaw(stream, -1); // note that not many things like this...
}
else if (prefix == null)
{
WriteRaw(stream, value.Length);
stream.Write(value, 0, value.Length);
stream.Write(Crlf, 0, 2);
// ${total-len}\r\n
// {prefix}{value}\r\n
if (prefix == null || prefix.Length == 0 || value == null)
{ // if no prefix, just use the non-prefixed version;
// even if prefixed, a null value writes as null, so can use the non-prefixed version
WriteUnified(writer, value);
}
else
{
WriteRaw(stream, prefix.Length + value.Length);
stream.Write(prefix, 0, prefix.Length);
stream.Write(value, 0, value.Length);
stream.Write(Crlf, 0, 2);
var span = writer.GetSpan(3 + MaxInt32TextLen); // note even with 2 max-len, we're still in same text range
span[0] = (byte)'$';
int bytes = WriteRaw(span, prefix.LongLength + value.LongLength, offset: 1);
writer.Advance(bytes);
writer.Write(prefix);
writer.Write(value);
span = writer.GetSpan(2);
WriteCrlf(span, 0);
writer.Advance(2);
}
}
private static void WriteUnified(Stream stream, long value)
private static void WriteUnified(PipeWriter writer, long value)
{
// note from specification: A client sends to the Redis server a RESP Array consisting of just Bulk Strings.
// (i.e. we can't just send ":123\r\n", we need to send "$3\r\n123\r\n"
stream.WriteByte((byte)'$');
WriteRaw(stream, value, withLengthPrefix: true);
// ${asc-len}\r\n = 3 + MaxInt32TextLen
// {asc}\r\n = MaxInt64TextLen + 2
var span = writer.GetSpan(5 + MaxInt32TextLen + MaxInt64TextLen);
span[0] = (byte)'$';
var bytes = WriteRaw(span, value, withLengthPrefix: true, offset: 1);
writer.Advance(bytes);
}
private void BeginReading()
......@@ -720,7 +841,7 @@ private static LocalCertificateSelectionCallback GetAmbientCertificateCallback()
return null;
}
SocketMode ISocketCallback.Connected(Stream stream, TextWriter log)
async ValueTask<SocketMode> ISocketCallback.ConnectedAsync(IDuplexPipe pipe, TextWriter log)
{
try
{
......@@ -735,36 +856,37 @@ SocketMode ISocketCallback.Connected(Stream stream, TextWriter log)
if (config.Ssl)
{
Multiplexer.LogLocked(log, "Configuring SSL");
var host = config.SslHost;
if (string.IsNullOrWhiteSpace(host)) host = Format.ToStringHostOnly(Bridge.ServerEndPoint.EndPoint);
var ssl = new SslStream(stream, false, config.CertificateValidationCallback,
config.CertificateSelectionCallback ?? GetAmbientCertificateCallback(),
EncryptionPolicy.RequireEncryption);
try
{
ssl.AuthenticateAsClient(host, config.SslProtocols);
Multiplexer.LogLocked(log, $"SSL connection established successfully using protocol: {ssl.SslProtocol}");
}
catch (AuthenticationException authexception)
{
RecordConnectionFailed(ConnectionFailureType.AuthenticationFailure, authexception);
Multiplexer.Trace("Encryption failure");
return SocketMode.Abort;
}
stream = ssl;
socketMode = SocketMode.Async;
throw new NotImplementedException("TLS");
//Multiplexer.LogLocked(log, "Configuring SSL");
//var host = config.SslHost;
//if (string.IsNullOrWhiteSpace(host)) host = Format.ToStringHostOnly(Bridge.ServerEndPoint.EndPoint);
//var ssl = new SslStream(stream, false, config.CertificateValidationCallback,
// config.CertificateSelectionCallback ?? GetAmbientCertificateCallback(),
// EncryptionPolicy.RequireEncryption);
//try
//{
// ssl.AuthenticateAsClient(host, config.SslProtocols);
// Multiplexer.LogLocked(log, $"SSL connection established successfully using protocol: {ssl.SslProtocol}");
//}
//catch (AuthenticationException authexception)
//{
// RecordConnectionFailed(ConnectionFailureType.AuthenticationFailure, authexception);
// Multiplexer.Trace("Encryption failure");
// return SocketMode.Abort;
//}
//stream = ssl;
//socketMode = SocketMode.Async;
}
OnWrapForLogging(ref stream, physicalName);
OnWrapForLogging(ref pipe, physicalName);
int bufferSize = config.WriteBuffer;
netStream = stream;
outStream = bufferSize <= 0 ? stream : new BufferedStream(stream, bufferSize);
_ioPipe = pipe;
Multiplexer.LogLocked(log, "Connected {0}", Bridge);
Bridge.OnConnected(this, log);
await Bridge.OnConnectedAsync(this, log);
return socketMode;
}
catch (Exception ex)
......@@ -884,7 +1006,7 @@ void ISocketCallback.OnHeartbeat()
}
}
partial void OnWrapForLogging(ref Stream stream, string name);
partial void OnWrapForLogging(ref IDuplexPipe pipe, string name);
private int ProcessBuffer(byte[] underlying, ref int offset, ref int count)
{
int messageCount = 0;
......
using System;
//using System;
namespace StackExchange.Redis
{
internal static class PlatformHelper
{
public static bool IsMono { get; } = Type.GetType("Mono.Runtime") != null;
//namespace StackExchange.Redis
//{
// internal static class PlatformHelper
// {
// public static bool IsMono { get; } = Type.GetType("Mono.Runtime") != null;
public static bool IsUnix { get; } = (int)Environment.OSVersion.Platform == 4
|| (int)Environment.OSVersion.Platform == 6
|| (int)Environment.OSVersion.Platform == 128;
// public static bool IsUnix { get; } = (int)Environment.OSVersion.Platform == 4
// || (int)Environment.OSVersion.Platform == 6
// || (int)Environment.OSVersion.Platform == 128;
public static SocketMode DefaultSocketMode = IsMono && IsUnix ? SocketMode.Async : SocketMode.Poll;
}
}
// public static SocketMode DefaultSocketMode = IsMono && IsUnix ? SocketMode.Async : SocketMode.Poll;
// }
//}
......@@ -2521,6 +2521,7 @@ public ScriptEvalMessage(int db, CommandFlags flags, byte[] hash, RedisKey[] key
: this(db, flags, RedisCommand.EVAL, null, hash, keys, values)
{
if (hash == null) throw new ArgumentNullException(nameof(hash));
if (hash.Length != ResultProcessor.ScriptLoadProcessor.Sha1HashLength) throw new ArgumentOutOfRangeException(nameof(hash), "Invalid hash length");
}
private ScriptEvalMessage(int db, CommandFlags flags, RedisCommand command, string script, byte[] hexHash, RedisKey[] keys, RedisValue[] values)
......@@ -2571,7 +2572,7 @@ internal override void WriteImpl(PhysicalConnection physical)
if (hexHash != null)
{
physical.WriteHeader(RedisCommand.EVALSHA, 2 + keys.Length + values.Length);
physical.WriteAsHex(hexHash);
physical.WriteSha1AsHex(hexHash);
}
else if (asciiHash != null)
{
......
using System;
using System;
using System.Collections.Generic;
using System.Text;
using System.Threading;
......@@ -250,7 +250,7 @@ public IEnumerable<Message> GetMessages(PhysicalConnection connection)
{
// need to get those sent ASAP; if they are stuck in the buffers, we die
multiplexer.Trace("Flushing and waiting for precondition responses");
connection.Flush();
connection.FlushAsync().Wait();
if (Monitor.Wait(lastBox, multiplexer.TimeoutMilliseconds))
{
if (!AreAllConditionsSatisfied(multiplexer))
......@@ -297,7 +297,7 @@ public IEnumerable<Message> GetMessages(PhysicalConnection connection)
if (explicitCheckForQueued && lastBox != null)
{
multiplexer.Trace("Flushing and waiting for precondition+queued responses");
connection.Flush(); // make sure they get sent, so we can check for QUEUED (and the pre-conditions if necessary)
connection.FlushAsync().Wait(); // make sure they get sent, so we can check for QUEUED (and the pre-conditions if necessary)
if (Monitor.Wait(lastBox, multiplexer.TimeoutMilliseconds))
{
if (!AreAllConditionsSatisfied(multiplexer))
......@@ -475,4 +475,4 @@ protected override bool SetResultCore(PhysicalConnection connection, Message mes
// return ExecuteTransactionAsync(flags);
// }
//}
}
\ No newline at end of file
}
......@@ -393,11 +393,13 @@ internal static bool IsSHA1(string script)
return script != null && sha1.IsMatch(script);
}
internal const int Sha1HashLength = 20;
internal static byte[] ParseSHA1(byte[] value)
{
if (value?.Length == 40)
if (value?.Length == Sha1HashLength * 2)
{
var tmp = new byte[20];
var tmp = new byte[Sha1HashLength];
int charIndex = 0;
for (int i = 0; i < tmp.Length; i++)
{
......@@ -412,9 +414,9 @@ internal static byte[] ParseSHA1(byte[] value)
internal static byte[] ParseSHA1(string value)
{
if (value?.Length == 40 && sha1.IsMatch(value))
if (value?.Length == (Sha1HashLength * 2) && sha1.IsMatch(value))
{
var tmp = new byte[20];
var tmp = new byte[Sha1HashLength];
int charIndex = 0;
for (int i = 0; i < tmp.Length; i++)
{
......@@ -442,7 +444,7 @@ protected override bool SetResultCore(PhysicalConnection connection, Message mes
{
case ResultType.BulkString:
var asciiHash = result.GetBlob();
if (asciiHash == null || asciiHash.Length != 40) return false;
if (asciiHash == null || asciiHash.Length != (Sha1HashLength * 2)) return false;
byte[] hash = null;
if (!message.IsInternalCall)
......
......@@ -450,19 +450,33 @@ internal bool IsSelectable(RedisCommand command)
return bridge?.IsConnected == true;
}
internal void OnEstablishing(PhysicalConnection connection, TextWriter log)
internal Task OnEstablishingAsync(PhysicalConnection connection, TextWriter log)
{
try
{
if (connection == null) return;
Handshake(connection, log);
if (connection == null) return Task.CompletedTask;
var handshake = HandshakeAsync(connection, log);
if (handshake.Status != TaskStatus.RanToCompletion)
return OnEstablishingAsyncAwaited(connection, handshake);
}
catch (Exception ex)
{
connection.RecordConnectionFailed(ConnectionFailureType.InternalFailure, ex);
}
return Task.CompletedTask;
}
async Task OnEstablishingAsyncAwaited(PhysicalConnection connection, Task handshake)
{
try
{
await handshake;
}
catch (Exception ex)
{
connection.RecordConnectionFailed(ConnectionFailureType.InternalFailure, ex);
}
}
internal void OnFullyEstablished(PhysicalConnection connection)
{
try
......@@ -627,13 +641,13 @@ private PhysicalBridge CreateBridge(ConnectionType type, TextWriter log)
return bridge;
}
private void Handshake(PhysicalConnection connection, TextWriter log)
private Task HandshakeAsync(PhysicalConnection connection, TextWriter log)
{
Multiplexer.LogLocked(log, "Server handshake");
if (connection == null)
{
Multiplexer.Trace("No connection!?");
return;
return Task.CompletedTask;
}
Message msg;
string password = Multiplexer.RawConfig.Password;
......@@ -684,7 +698,7 @@ private void Handshake(PhysicalConnection connection, TextWriter log)
}
}
Multiplexer.LogLocked(log, "Flushing outbound buffer");
connection.Flush();
return connection.FlushAsync();
}
private void SetConfig<T>(ref T field, T value, [CallerMemberName] string caller = null)
......
using System;
using System.Collections.Generic;
using System.IO;
using System.IO.Pipelines;
using System.Net;
using System.Net.Sockets;
using System.Threading;
using System.Threading.Tasks;
using Pipelines.Sockets.Unofficial;
namespace StackExchange.Redis
{
internal enum SocketMode
{
Abort,
[Obsolete("just don't", error: true)]
Poll,
Async
}
......@@ -22,9 +26,9 @@ internal partial interface ISocketCallback
/// <summary>
/// Indicates that a socket has connected
/// </summary>
/// <param name="stream">The network stream for this socket.</param>
/// <param name="pipe">The network stream for this socket.</param>
/// <param name="log">A text logger to write to.</param>
SocketMode Connected(Stream stream, TextWriter log);
ValueTask<SocketMode> ConnectedAsync(IDuplexPipe pipe, TextWriter log);
/// <summary>
/// Indicates that the socket has signalled an error condition
......@@ -104,16 +108,6 @@ internal enum ManagerState
ProcessErrorQueue,
}
private static readonly ParameterizedThreadStart writeAllQueues = context =>
{
try { ((SocketManager)context).WriteAllQueues(); } catch { }
};
private static readonly WaitCallback writeOneQueue = context =>
{
try { ((SocketManager)context).WriteOneQueue(); } catch { }
};
private readonly Queue<PhysicalBridge> writeQueue = new Queue<PhysicalBridge>();
private bool isDisposed;
private readonly bool useHighPrioritySocketThreads = true;
......@@ -140,14 +134,12 @@ public SocketManager(string name, bool useHighPrioritySocketThreads)
Name = name;
this.useHighPrioritySocketThreads = useHighPrioritySocketThreads;
// we need a dedicated writer, because when under heavy ambient load
// (a busy asp.net site, for example), workers are not reliable enough
var dedicatedWriter = new Thread(writeAllQueues, 32 * 1024); // don't need a huge stack;
dedicatedWriter.Priority = useHighPrioritySocketThreads ? ThreadPriority.AboveNormal : ThreadPriority.Normal;
dedicatedWriter.Name = name + ":Write";
dedicatedWriter.IsBackground = true; // should not keep process alive
dedicatedWriter.Start(this); // will self-exit when disposed
_writeOneQueueAsync = () => WriteOneQueueAsync();
Task.Run(() => WriteAllQueuesAsync());
}
private readonly Func<Task> _writeOneQueueAsync;
private enum CallbackOperation
{
......@@ -171,54 +163,16 @@ public void Dispose()
internal SocketToken BeginConnect(EndPoint endpoint, ISocketCallback callback, ConnectionMultiplexer multiplexer, TextWriter log)
{
void RunWithCompletionType(Func<AsyncCallback, IAsyncResult> beginAsync, AsyncCallback asyncCallback)
{
void proxyCallback(IAsyncResult ar)
{
if (!ar.CompletedSynchronously)
{
asyncCallback(ar);
}
}
var result = beginAsync(proxyCallback);
if (result.CompletedSynchronously)
{
result.AsyncWaitHandle.WaitOne();
asyncCallback(result);
}
}
var addressFamily = endpoint.AddressFamily == AddressFamily.Unspecified ? AddressFamily.InterNetwork : endpoint.AddressFamily;
var socket = new Socket(addressFamily, SocketType.Stream, ProtocolType.Tcp);
SetFastLoopbackOption(socket);
socket.NoDelay = true;
try
{
var formattedEndpoint = Format.ToString(endpoint);
var tuple = Tuple.Create(socket, callback);
multiplexer.LogLocked(log, "BeginConnect: {0}", formattedEndpoint);
// A work-around for a Mono bug in BeginConnect(EndPoint endpoint, AsyncCallback callback, object state)
if (endpoint is DnsEndPoint dnsEndpoint)
{
RunWithCompletionType(
cb => socket.BeginConnect(dnsEndpoint.Host, dnsEndpoint.Port, cb, tuple),
ar => {
multiplexer.LogLocked(log, "EndConnect: {0}", formattedEndpoint);
EndConnectImpl(ar, multiplexer, log, tuple);
multiplexer.LogLocked(log, "Connect complete: {0}", formattedEndpoint);
});
}
else
{
RunWithCompletionType(
cb => socket.BeginConnect(endpoint, cb, tuple),
ar => {
multiplexer.LogLocked(log, "EndConnect: {0}", formattedEndpoint);
EndConnectImpl(ar, multiplexer, log, tuple);
multiplexer.LogLocked(log, "Connect complete: {0}", formattedEndpoint);
});
}
SocketConnection.ConnectAsync(endpoint, PipeOptions.Default,
conn => EndConnectAsync(conn, multiplexer, log, callback), socket);
}
catch (NotImplementedException ex)
{
......@@ -232,25 +186,7 @@ void proxyCallback(IAsyncResult ar)
return token;
}
internal void SetFastLoopbackOption(Socket socket)
{
// SIO_LOOPBACK_FAST_PATH (https://msdn.microsoft.com/en-us/library/windows/desktop/jj841212%28v=vs.85%29.aspx)
// Speeds up localhost operations significantly. OK to apply to a socket that will not be hooked up to localhost,
// or will be subject to WFP filtering.
const int SIO_LOOPBACK_FAST_PATH = -1744830448;
// windows only
if (Environment.OSVersion.Platform == PlatformID.Win32NT)
{
// Win8/Server2012+ only
var osVersion = Environment.OSVersion.Version;
if (osVersion.Major > 6 || (osVersion.Major == 6 && osVersion.Minor >= 2))
{
byte[] optionInValue = BitConverter.GetBytes(1);
socket.IOControl(SIO_LOOPBACK_FAST_PATH, optionInValue, null);
}
}
}
internal void RequestWrite(PhysicalBridge bridge, bool forced)
{
......@@ -265,35 +201,29 @@ internal void RequestWrite(PhysicalBridge bridge, bool forced)
}
else if (writeQueue.Count >= 2)
{ // struggling are we? let's have some help dealing with the backlog
ThreadPool.QueueUserWorkItem(writeOneQueue, this);
Task.Run(_writeOneQueueAsync);
}
}
}
}
internal void Shutdown(SocketToken token)
{
Shutdown(token.Socket);
}
private void EndConnectImpl(IAsyncResult ar, ConnectionMultiplexer multiplexer, TextWriter log, Tuple<Socket, ISocketCallback> tuple)
private async Task EndConnectAsync(SocketConnection connection, ConnectionMultiplexer multiplexer, TextWriter log, ISocketCallback callback)
{
try
{
bool ignoreConnect = false;
ShouldIgnoreConnect(tuple.Item2, ref ignoreConnect);
var socket = connection?.Socket;
ShouldIgnoreConnect(callback, ref ignoreConnect);
if (ignoreConnect) return;
var socket = tuple.Item1;
var callback = tuple.Item2;
socket.EndConnect(ar);
var netStream = new NetworkStream(socket, false);
var socketMode = callback?.Connected(netStream, log) ?? SocketMode.Abort;
var socketMode = callback == null ? SocketMode.Abort : await callback.ConnectedAsync(connection, log);
switch (socketMode)
{
case SocketMode.Poll:
multiplexer.LogLocked(log, "Starting poll");
OnAddRead(socket, callback);
break;
case SocketMode.Async:
multiplexer.LogLocked(log, "Starting read");
try
......@@ -313,10 +243,9 @@ private void EndConnectImpl(IAsyncResult ar, ConnectionMultiplexer multiplexer,
catch (ObjectDisposedException)
{
multiplexer.LogLocked(log, "(socket shutdown)");
if (tuple != null)
if (callback != null)
{
try
{ tuple.Item2.Error(); }
try { callback.Error(); }
catch (Exception inner)
{
ConnectionMultiplexer.TraceWithoutContext(inner.Message);
......@@ -326,10 +255,9 @@ private void EndConnectImpl(IAsyncResult ar, ConnectionMultiplexer multiplexer,
catch(Exception outer)
{
ConnectionMultiplexer.TraceWithoutContext(outer.Message);
if (tuple != null)
if (callback != null)
{
try
{ tuple.Item2.Error(); }
try { callback.Error(); }
catch (Exception inner)
{
ConnectionMultiplexer.TraceWithoutContext(inner.Message);
......@@ -355,7 +283,7 @@ private void Shutdown(Socket socket)
}
}
private void WriteAllQueues()
private async Task WriteAllQueuesAsync()
{
while (true)
{
......@@ -372,7 +300,7 @@ private void WriteAllQueues()
bridge = writeQueue.Dequeue();
}
switch (bridge.WriteQueue(200))
switch (await bridge.WriteQueueAsync(200))
{
case WriteResult.MoreWork:
case WriteResult.QueueEmptyAfterWrite:
......@@ -400,18 +328,22 @@ private void WriteAllQueues()
}
}
private void WriteOneQueue()
private Task WriteOneQueueAsync()
{
PhysicalBridge bridge;
lock (writeQueue)
{
bridge = writeQueue.Count == 0 ? null : writeQueue.Dequeue();
}
if (bridge == null) return;
if (bridge == null) return Task.CompletedTask;
return WriteOneQueueAsyncImpl(bridge);
}
private async Task WriteOneQueueAsyncImpl(PhysicalBridge bridge)
{
bool keepGoing;
do
{
switch (bridge.WriteQueue(-1))
switch (await bridge.WriteQueueAsync(-1))
{
case WriteResult.MoreWork:
case WriteResult.QueueEmptyAfterWrite:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment