Commit e815fcb0 authored by Marc Gravell's avatar Marc Gravell

Resumable scanning operators; KEYS first

parent 11af3494
using System.Linq;
using NUnit.Framework;
using System;
using System.Threading;
using System.Collections.Generic;
namespace StackExchange.Redis.Tests
{
[TestFixture]
public class Scans : TestBase
{
[Test]
[TestCase(true)]
[TestCase(false)]
public void KeysScan(bool supported)
{
string[] disabledCommands = supported ? null : new[] { "scan" };
using (var conn = Create(disabledCommands: disabledCommands, allowAdmin: true))
{
const int DB = 7;
var db = conn.GetDatabase(DB);
var server = GetServer(conn);
server.FlushDatabase(DB);
for(int i = 0 ; i < 100 ; i++)
{
db.StringSet("KeysScan:" + i, Guid.NewGuid().ToString(), flags: CommandFlags.FireAndForget);
}
var seq = server.Keys(DB, pageSize:50);
bool isScanning = seq is IScanning;
Assert.AreEqual(supported, isScanning, "scanning");
Assert.AreEqual(100, seq.Distinct().Count());
Assert.AreEqual(100, seq.Distinct().Count());
Assert.AreEqual(100, server.Keys(DB, "KeysScan:*").Distinct().Count());
// 7, 70, 71, ..., 79
Assert.AreEqual(11, server.Keys(DB, "KeysScan:7*").Distinct().Count());
}
}
public void ScansIScanning()
{
using (var conn = Create(allowAdmin: true))
{
const int DB = 7;
var db = conn.GetDatabase(DB);
var server = GetServer(conn);
server.FlushDatabase(DB);
for (int i = 0; i < 100; i++)
{
db.StringSet("ScansRepeatable:" + i, Guid.NewGuid().ToString(), flags: CommandFlags.FireAndForget);
}
var seq = server.Keys(DB, pageSize: 15);
using(var iter = seq.GetEnumerator())
{
IScanning s0 = (IScanning)seq, s1 = (IScanning)iter;
Assert.AreEqual(15, s0.PageSize);
Assert.AreEqual(15, s1.PageSize);
// start at zero
Assert.AreEqual(0, s0.CurrentCursor);
Assert.AreEqual(0, s0.NextCursor);
Assert.AreEqual(s0.CurrentCursor, s1.CurrentCursor);
Assert.AreEqual(s0.NextCursor, s1.NextCursor);
for(int i = 0 ; i < 47 ; i++)
{
Assert.IsTrue(iter.MoveNext());
}
// non-zero in the middle
Assert.AreNotEqual(0, s0.CurrentCursor);
Assert.AreNotEqual(0, s0.NextCursor);
Assert.AreEqual(s0.CurrentCursor, s1.CurrentCursor);
Assert.AreEqual(s0.NextCursor, s1.NextCursor);
Assert.AreNotEqual(s1.CurrentCursor, s1.NextCursor, "iter");
Assert.AreNotEqual(s0.CurrentCursor, s0.NextCursor, "seq");
for (int i = 0; i < 53; i++)
{
Assert.IsTrue(iter.MoveNext());
}
// zero "next" at the end
Assert.IsFalse(iter.MoveNext());
Assert.AreEqual(0, s0.NextCursor);
Assert.AreEqual(0, s1.NextCursor);
Assert.AreNotEqual(0, s0.CurrentCursor);
Assert.AreNotEqual(0, s1.CurrentCursor);
}
}
}
public void ScanResume()
{
using (var conn = Create(allowAdmin: true))
{
const int DB = 7;
var db = conn.GetDatabase(DB);
var server = GetServer(conn);
server.FlushDatabase(DB);
int i;
for (i = 0; i < 100; i++)
{
db.StringSet("ScanResume:" + i, Guid.NewGuid().ToString(), flags: CommandFlags.FireAndForget);
}
var expected = new HashSet<string>();
long snap = 0;
i = 0;
var seq = server.Keys(DB, pageSize: 15);
foreach(var key in seq)
{
i++;
if (i < 57) continue;
if (i == 57)
{
snap = ((IScanning)seq).CurrentCursor;
}
expected.Add((string)key);
}
Assert.AreNotEqual(43, expected.Count);
Assert.AreNotEqual(0, snap);
seq = server.Keys(DB, pageSize: 15, cursor: snap);
int count = 0;
foreach(var key in seq)
{
expected.Remove((string)key);
count++;
}
Assert.AreEqual(0, expected.Count);
Assert.AreEqual(55, count); // expect some overlap due to paged, etc
}
}
[Test]
[TestCase(true)]
[TestCase(false)]
......

Microsoft Visual Studio Solution File, Format Version 12.00
# Visual Studio 2013
VisualStudioVersion = 12.0.30501.0
VisualStudioVersion = 12.0.30723.0
MinimumVisualStudioVersion = 10.0.40219.1
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "StackExchange.Redis", "StackExchange.Redis\StackExchange.Redis.csproj", "{7CEC07F2-8C03-4C42-B048-738B215824C1}"
EndProject
......@@ -24,7 +24,7 @@ Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Redis Configs", "Redis Conf
Redis Configs\redis-cli master.cmd = Redis Configs\redis-cli master.cmd
Redis Configs\redis-cli secure.cmd = Redis Configs\redis-cli secure.cmd
Redis Configs\redis-cli slave.cmd = Redis Configs\redis-cli slave.cmd
Redis Configs\redis-server alll local.cmd = Redis Configs\redis-server alll local.cmd
Redis Configs\redis-server all local.cmd = Redis Configs\redis-server all local.cmd
Redis Configs\redis-server master.cmd = Redis Configs\redis-server master.cmd
Redis Configs\redis-server secure.cmd = Redis Configs\redis-server secure.cmd
Redis Configs\redis-server slave.cmd = Redis Configs\redis-server slave.cmd
......
......@@ -16,6 +16,27 @@ public partial interface IRedis : IRedisAsync
TimeSpan Ping(CommandFlags flags = CommandFlags.None);
}
/// <summary>
/// Represents a resumable, cursor-based scanning operation
/// </summary>
public interface IScanning
{
/// <summary>
/// Returns the cursor that represents the *active* page of results (not the pending/next page of results)
/// </summary>
long CurrentCursor { get; }
/// <summary>
/// Returns the cursor for the *pending/next* page of results
/// </summary>
long NextCursor { get; }
/// <summary>
/// The page size of the current operation
/// </summary>
int PageSize { get; }
}
[Conditional("DEBUG")]
[AttributeUsage(AttributeTargets.Method, AllowMultiple = false, Inherited = true)]
internal class IgnoreNamePrefixAttribute : Attribute
......
......@@ -229,7 +229,15 @@ public partial interface IServer : IRedis
/// <remarks>Warning: consider KEYS as a command that should only be used in production environments with extreme care.</remarks>
/// <remarks>http://redis.io/commands/keys</remarks>
/// <remarks>http://redis.io/commands/scan</remarks>
IEnumerable<RedisKey> Keys(int database = 0, RedisValue pattern = default(RedisValue), int pageSize = 10, CommandFlags flags = CommandFlags.None);
IEnumerable<RedisKey> Keys(int database, RedisValue pattern, int pageSize, CommandFlags flags);
/// <summary>
/// Returns all keys matching pattern; the KEYS or SCAN commands will be used based on the server capabilities.
/// </summary>
/// <remarks>Warning: consider KEYS as a command that should only be used in production environments with extreme care.</remarks>
/// <remarks>http://redis.io/commands/keys</remarks>
/// <remarks>http://redis.io/commands/scan</remarks>
IEnumerable<RedisKey> Keys(int database = 0, RedisValue pattern = default(RedisValue), int pageSize = 10, long cursor = 0, CommandFlags flags = CommandFlags.None);
/// <summary>
/// Return the time of the last DB save executed with success. A client may check if a BGSAVE command succeeded reading the LASTSAVE value, then issuing a BGSAVE command and checking at regular intervals every N seconds if LASTSAVE changed.
......
using System;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
namespace StackExchange.Redis
......@@ -137,5 +139,216 @@ private ResultProcessor.TimingProcessor.TimerMessage GetTimerMessage(CommandFlag
// note: this usually means: twemproxy - in which case we're fine anyway, since the proxy does the routing
return ResultProcessor.TimingProcessor.CreateMessage(0, flags, RedisCommand.EXISTS, (RedisValue)multiplexer.UniqueId);
}
internal abstract class CursorEnumerableBase
{
internal const int DefaultPageSize = 10;
internal static bool IsNil(RedisValue pattern)
{
if (pattern.IsNullOrEmpty) return true;
if (pattern.IsInteger) return false;
byte[] rawValue = pattern;
return rawValue.Length == 1 && rawValue[0] == '*';
}
}
internal abstract class CursorEnumerableBase<T> : CursorEnumerableBase, IEnumerable<T>, IScanning
{
private readonly RedisBase redis;
private readonly ServerEndPoint server;
protected readonly int db;
protected readonly CommandFlags flags;
protected readonly int pageSize;
protected readonly long initialCursor;
protected CursorEnumerableBase(RedisBase redis, ServerEndPoint server, int db, int pageSize, long cursor, CommandFlags flags)
{
this.redis = redis;
this.server = server;
this.db = db;
this.pageSize = pageSize;
this.flags = flags;
this.initialCursor = cursor;
}
public IEnumerator<T> GetEnumerator()
{
return new CursorEnumerator(this);
}
System.Collections.IEnumerator System.Collections.IEnumerable.GetEnumerator()
{
return GetEnumerator();
}
internal struct ScanResult
{
public readonly long Cursor;
public readonly T[] Values;
public ScanResult(long cursor, T[] values)
{
this.Cursor = cursor;
this.Values = values;
}
}
protected abstract Message CreateMessage(long cursor);
private long currentCursor, nextCursor;
internal void SetPosition(long current, long next)
{
Interlocked.Exchange(ref currentCursor, current);
Interlocked.Exchange(ref nextCursor, next);
}
protected abstract ResultProcessor<ScanResult> Processor { get; }
protected ScanResult GetNextPageSync(long cursor)
{
return redis.ExecuteSync(CreateMessage(cursor), Processor, server);
}
protected Task<ScanResult> GetNextPageAsync(long cursor)
{
return redis.ExecuteAsync(CreateMessage(cursor), Processor, server);
}
protected ScanResult Wait(Task<ScanResult> pending)
{
return redis.Wait(pending);
}
class CursorEnumerator : IEnumerator<T>, IScanning
{
private CursorEnumerableBase<T> parent;
public CursorEnumerator(CursorEnumerableBase<T> parent)
{
if (parent == null) throw new ArgumentNullException("parent");
this.parent = parent;
Reset();
}
public T Current
{
get { return page[pageIndex]; }
}
void IDisposable.Dispose() { parent = null; state = State.Disposed; }
object System.Collections.IEnumerator.Current
{
get { return page[pageIndex]; ; }
}
private bool SimpleNext()
{
if (page != null && ++pageIndex < page.Length)
{
// first of a new page? cool; start a new background op, because we're about to exit the iterator
if (pageIndex == 0 && pending == null && nextCursor != 0)
{
pending = parent.GetNextPageAsync(nextCursor);
}
return true;
}
return false;
}
T[] page;
Task<ScanResult> pending;
int pageIndex;
private long currentCursor, nextCursor;
private State state;
private enum State : byte
{
Initial,
Running,
Complete,
Disposed,
}
void ProcessReply(ScanResult result)
{
pending = null;
page = result.Values;
pageIndex = -1;
parent.SetPosition(currentCursor = nextCursor, nextCursor = result.Cursor);
}
public bool MoveNext()
{
switch(state)
{
case State.Complete:
return false;
case State.Initial:
ProcessReply(parent.GetNextPageSync(nextCursor));
state = State.Running;
goto case State.Running;
case State.Running:
// are we working through the current buffer?
if (SimpleNext()) return true;
// do we have an outstanding operation? wait for the background task to finish
while (pending != null)
{
ProcessReply(parent.Wait(pending));
if (SimpleNext()) return true;
}
// nothing outstanding? wait synchronously
while(nextCursor != 0)
{
ProcessReply(parent.GetNextPageSync(nextCursor));
if (SimpleNext()) return true;
}
// we're exhausted
state = State.Complete;
return false;
case State.Disposed:
default:
throw new ObjectDisposedException(GetType().Name);
}
}
public void Reset()
{
if(state == State.Disposed) throw new ObjectDisposedException(GetType().Name);
nextCursor = currentCursor = parent.initialCursor;
state = State.Initial;
page = null;
pageIndex = -1;
pending = null;
}
long IScanning.CurrentCursor
{
get { return currentCursor; }
}
long IScanning.NextCursor
{
get { return nextCursor; }
}
int IScanning.PageSize
{
get { return parent.pageSize; }
}
}
long IScanning.CurrentCursor
{
get { return Interlocked.Read(ref currentCursor); }
}
long IScanning.NextCursor
{
get { return Interlocked.Read(ref nextCursor); }
}
int IScanning.PageSize
{
get { return pageSize; }
}
}
}
}
......@@ -5,6 +5,7 @@
using System.Net;
using System.Security.Cryptography;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
namespace StackExchange.Redis
......@@ -268,18 +269,24 @@ public Task<string> InfoRawAsync(RedisValue section = default(RedisValue), Comma
return ExecuteAsync(msg, ResultProcessor.String);
}
public IEnumerable<RedisKey> Keys(int database = 0, RedisValue pattern = default(RedisValue), int pageSize = KeysScanIterator.DefaultPageSize, CommandFlags flags = CommandFlags.None)
IEnumerable<RedisKey> IServer.Keys(int database, RedisValue pattern, int pageSize, CommandFlags flags)
{
return Keys(database, pattern, pageSize, 0, flags);
}
public IEnumerable<RedisKey> Keys(int database = 0, RedisValue pattern = default(RedisValue), int pageSize = CursorEnumerableBase.DefaultPageSize, long cursor = 0, CommandFlags flags = CommandFlags.None)
{
if (pageSize <= 0) throw new ArgumentOutOfRangeException("pageSize");
if (KeysScanIterator.IsNil(pattern)) pattern = RedisLiterals.Wildcard;
if (CursorEnumerableBase.IsNil(pattern)) pattern = RedisLiterals.Wildcard;
if (multiplexer.CommandMap.IsAvailable(RedisCommand.SCAN))
{
var features = server.GetFeatures();
if (features.Scan) return new KeysScanIterator(this, database, pattern, pageSize, flags).Read();
if (features.Scan) return new KeysScanEnumerable(this, database, pattern, pageSize, cursor, flags);
}
if (cursor != 0) throw new InvalidOperationException("A cursor cannot be used with KEYS");
Message msg = Message.Create(database, flags, RedisCommand.KEYS, pattern);
return ExecuteSync(msg, ResultProcessor.RedisKeyArray);
}
......@@ -604,38 +611,6 @@ ResultProcessor<bool> GetSaveResultProcessor(SaveType type)
}
}
struct KeysScanResult
{
public static readonly ResultProcessor<KeysScanResult> Processor = new KeysResultProcessor();
public readonly long Cursor;
public readonly RedisKey[] Keys;
public KeysScanResult(long cursor, RedisKey[] keys)
{
this.Cursor = cursor;
this.Keys = keys;
}
private class KeysResultProcessor : ResultProcessor<KeysScanResult>
{
protected override bool SetResultCore(PhysicalConnection connection, Message message, RawResult result)
{
switch (result.Type)
{
case ResultType.MultiBulk:
var arr = result.GetItems();
long i64;
if (arr.Length == 2 && arr[1].Type == ResultType.MultiBulk && arr[0].TryGetInt64(out i64))
{
var keysResult = new KeysScanResult(i64, arr[1].GetItemsAsKeys());
SetResult(message, keysResult);
return true;
}
break;
}
return false;
}
}
}
static class ScriptHash
{
static readonly byte[] hex = {
......@@ -664,64 +639,19 @@ public static RedisValue Hash(string value)
}
}
}
sealed class KeysScanIterator
{
internal const int DefaultPageSize = 10;
private readonly int db;
private readonly CommandFlags flags;
private readonly int pageSize;
sealed class KeysScanEnumerable : CursorEnumerableBase<RedisKey>
{
private readonly RedisValue pattern;
private readonly RedisServer server;
public KeysScanIterator(RedisServer server, int db, RedisValue pattern, int pageSize, CommandFlags flags)
public KeysScanEnumerable(RedisServer server, int db, RedisValue pattern, int pageSize, long cursor, CommandFlags flags)
: base(server, server.server, db, pageSize, cursor, flags)
{
this.pageSize = pageSize;
this.db = db;
this.pattern = pattern;
this.flags = flags;
this.server = server;
}
public static bool IsNil(RedisValue pattern)
protected override Message CreateMessage(long cursor)
{
if (pattern.IsNullOrEmpty) return true;
if (pattern.IsInteger) return false;
byte[] rawValue = pattern;
return rawValue.Length == 1 && rawValue[0] == '*';
}
public IEnumerable<RedisKey> Read()
{
var msg = CreateMessage(0, false);
KeysScanResult current = server.ExecuteSync(msg, KeysScanResult.Processor);
Task<KeysScanResult> pending;
do
{
// kick off the next immediately, but don't wait for it yet
msg = CreateMessage(current.Cursor, true);
pending = msg == null ? null : server.ExecuteAsync(msg, KeysScanResult.Processor);
// now we can iterate the rows
var keys = current.Keys;
for (int i = 0; i < keys.Length; i++)
yield return keys[i];
// wait for the next, if any
if (pending != null)
{
current = server.Wait(pending);
}
} while (pending != null);
}
Message CreateMessage(long cursor, bool running)
{
if (cursor == 0 && running) return null; // end of the line
if (IsNil(pattern))
{
if (pageSize == DefaultPageSize)
......@@ -745,6 +675,32 @@ Message CreateMessage(long cursor, bool running)
}
}
}
protected override ResultProcessor<ScanResult> Processor
{
get { return processor; }
}
public static readonly ResultProcessor<ScanResult> processor = new KeysResultProcessor();
private class KeysResultProcessor : ResultProcessor<ScanResult>
{
protected override bool SetResultCore(PhysicalConnection connection, Message message, RawResult result)
{
switch (result.Type)
{
case ResultType.MultiBulk:
var arr = result.GetItems();
long i64;
if (arr.Length == 2 && arr[1].Type == ResultType.MultiBulk && arr[0].TryGetInt64(out i64))
{
var keysResult = new ScanResult(i64, arr[1].GetItemsAsKeys());
SetResult(message, keysResult);
return true;
}
break;
}
return false;
}
}
}
#region Sentinel
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment