Commit 0cc1d8f9 authored by Marc Gravell's avatar Marc Gravell

Explicit control of pattern mode;l fixes issue #114

parent b6e4bd65
...@@ -11,6 +11,48 @@ namespace StackExchange.Redis.Tests ...@@ -11,6 +11,48 @@ namespace StackExchange.Redis.Tests
public class PubSub : TestBase public class PubSub : TestBase
{ {
[Test]
public void ExplicitPublishMode()
{
using(var mx = Create(channelPrefix: "foo:"))
{
var pub = mx.GetSubscriber();
int a = 0, b = 0, c = 0, d = 0;
pub.Subscribe(new RedisChannel("*bcd", RedisChannel.PatternMode.Literal), (x, y) =>
{
Interlocked.Increment(ref a);
});
pub.Subscribe(new RedisChannel("a*cd", RedisChannel.PatternMode.Pattern), (x, y) =>
{
Interlocked.Increment(ref b);
});
pub.Subscribe(new RedisChannel("ab*d", RedisChannel.PatternMode.Auto), (x, y) =>
{
Interlocked.Increment(ref c);
});
pub.Subscribe("abc*", (x, y) =>
{
Interlocked.Increment(ref d);
});
Thread.Sleep(1000);
pub.Publish("abcd", "efg");
Thread.Sleep(500);
Assert.AreEqual(0, Thread.VolatileRead(ref a), "a1");
Assert.AreEqual(1, Thread.VolatileRead(ref b), "b1");
Assert.AreEqual(1, Thread.VolatileRead(ref c), "c1");
Assert.AreEqual(1, Thread.VolatileRead(ref d), "d1");
pub.Publish("*bcd", "efg");
Thread.Sleep(500);
Assert.AreEqual(1, Thread.VolatileRead(ref a), "a2");
//Assert.AreEqual(1, Thread.VolatileRead(ref b), "b2");
//Assert.AreEqual(1, Thread.VolatileRead(ref c), "c2");
//Assert.AreEqual(1, Thread.VolatileRead(ref d), "d2");
}
}
[Test] [Test]
[TestCase(true, null, false)] [TestCase(true, null, false)]
[TestCase(false, null, false)] [TestCase(false, null, false)]
......
...@@ -740,7 +740,7 @@ void MatchResult(RawResult result) ...@@ -740,7 +740,7 @@ void MatchResult(RawResult result)
} }
// invoke the handlers // invoke the handlers
var channel = items[1].AsRedisChannel(ChannelPrefix); var channel = items[1].AsRedisChannel(ChannelPrefix, RedisChannel.PatternMode.Literal);
multiplexer.Trace("MESSAGE: " + channel, physicalName); multiplexer.Trace("MESSAGE: " + channel, physicalName);
if (!channel.IsNull) if (!channel.IsNull)
{ {
...@@ -750,11 +750,11 @@ void MatchResult(RawResult result) ...@@ -750,11 +750,11 @@ void MatchResult(RawResult result)
} }
else if (items.Length >= 4 && items[0].IsEqual(pmessage)) else if (items.Length >= 4 && items[0].IsEqual(pmessage))
{ {
var channel = items[2].AsRedisChannel(ChannelPrefix); var channel = items[2].AsRedisChannel(ChannelPrefix, RedisChannel.PatternMode.Literal);
multiplexer.Trace("PMESSAGE: " + channel, physicalName); multiplexer.Trace("PMESSAGE: " + channel, physicalName);
if (!channel.IsNull) if (!channel.IsNull)
{ {
var sub = items[1].AsRedisChannel(ChannelPrefix); var sub = items[1].AsRedisChannel(ChannelPrefix, RedisChannel.PatternMode.Pattern);
multiplexer.OnMessage(sub, channel, items[3].AsRedisValue()); multiplexer.OnMessage(sub, channel, items[3].AsRedisValue());
} }
return; // AND STOP PROCESSING! return; // AND STOP PROCESSING!
......
...@@ -68,7 +68,7 @@ public override string ToString() ...@@ -68,7 +68,7 @@ public override string ToString()
return "(unknown)"; return "(unknown)";
} }
} }
internal RedisChannel AsRedisChannel(byte[] channelPrefix) internal RedisChannel AsRedisChannel(byte[] channelPrefix, RedisChannel.PatternMode mode)
{ {
switch (resultType) switch (resultType)
{ {
...@@ -76,7 +76,7 @@ internal RedisChannel AsRedisChannel(byte[] channelPrefix) ...@@ -76,7 +76,7 @@ internal RedisChannel AsRedisChannel(byte[] channelPrefix)
case ResultType.BulkString: case ResultType.BulkString:
if (channelPrefix == null) if (channelPrefix == null)
{ {
return (RedisChannel)GetBlob(); return new RedisChannel(GetBlob(), mode);
} }
if (AssertStarts(channelPrefix)) if (AssertStarts(channelPrefix))
{ {
...@@ -84,7 +84,7 @@ internal RedisChannel AsRedisChannel(byte[] channelPrefix) ...@@ -84,7 +84,7 @@ internal RedisChannel AsRedisChannel(byte[] channelPrefix)
byte[] copy = new byte[count - channelPrefix.Length]; byte[] copy = new byte[count - channelPrefix.Length];
Buffer.BlockCopy(src, offset + channelPrefix.Length, copy, 0, copy.Length); Buffer.BlockCopy(src, offset + channelPrefix.Length, copy, 0, copy.Length);
return (RedisChannel)copy; return new RedisChannel(copy, mode);
} }
return default(RedisChannel); return default(RedisChannel);
default: default:
......
...@@ -13,9 +13,36 @@ public struct RedisChannel : IEquatable<RedisChannel> ...@@ -13,9 +13,36 @@ public struct RedisChannel : IEquatable<RedisChannel>
private readonly byte[] value; private readonly byte[] value;
private RedisChannel(byte[] value) /// <summary>
/// Create a new redis channel from a buffer, explicitly controlling the pattern mode
/// </summary>
public RedisChannel(byte[] value, PatternMode mode) : this(value, DeterminePatternBased(value, mode))
{
}
/// <summary>
/// Create a new redis channel from a string, explicitly controlling the pattern mode
/// </summary>
public RedisChannel(string value, PatternMode mode) : this(value == null ? null : Encoding.UTF8.GetBytes(value), mode)
{
}
private RedisChannel(byte[] value, bool isPatternBased)
{ {
this.value = value; this.value = value;
this.IsPatternBased = isPatternBased;
}
private static bool DeterminePatternBased(byte[] value, PatternMode mode)
{
switch (mode)
{
case PatternMode.Auto:
return value != null && Array.IndexOf(value, (byte)'*') >= 0;
case PatternMode.Literal: return false;
case PatternMode.Pattern: return true;
default:
throw new ArgumentOutOfRangeException("mode");
}
} }
/// <summary> /// <summary>
...@@ -81,7 +108,7 @@ internal bool IsNull ...@@ -81,7 +108,7 @@ internal bool IsNull
/// </summary> /// </summary>
public static bool operator ==(RedisChannel x, RedisChannel y) public static bool operator ==(RedisChannel x, RedisChannel y)
{ {
return RedisValue.Equals(x.value, y.value); return x.IsPatternBased == y.IsPatternBased && RedisValue.Equals(x.value, y.value);
} }
/// <summary> /// <summary>
...@@ -141,7 +168,8 @@ public override bool Equals(object obj) ...@@ -141,7 +168,8 @@ public override bool Equals(object obj)
/// </summary> /// </summary>
public bool Equals(RedisChannel other) public bool Equals(RedisChannel other)
{ {
return RedisValue.Equals(this.value, other.value); return this.IsPatternBased == other.IsPatternBased &&
RedisValue.Equals(this.value, other.value);
} }
/// <summary> /// <summary>
...@@ -149,7 +177,7 @@ public bool Equals(RedisChannel other) ...@@ -149,7 +177,7 @@ public bool Equals(RedisChannel other)
/// </summary> /// </summary>
public override int GetHashCode() public override int GetHashCode()
{ {
return RedisValue.GetHashCode(this.value); return RedisValue.GetHashCode(this.value) + (IsPatternBased ? 1 : 0);
} }
/// <summary> /// <summary>
...@@ -180,9 +208,25 @@ internal RedisChannel Clone() ...@@ -180,9 +208,25 @@ internal RedisChannel Clone()
return clone; return clone;
} }
internal bool Contains(byte value) internal readonly bool IsPatternBased;
/// <summary>
/// The matching pattern for this channel
/// </summary>
public enum PatternMode
{ {
return this.value != null && Array.IndexOf(this.value, value) >= 0; /// <summary>
/// Will be treated as a pattern if it includes *
/// </summary>
Auto = 0,
/// <summary>
/// Never a pattern
/// </summary>
Literal = 1,
/// <summary>
/// Always a pattern
/// </summary>
Pattern = 2
} }
/// <summary> /// <summary>
/// Create a channel name from a String /// Create a channel name from a String
...@@ -190,7 +234,7 @@ internal bool Contains(byte value) ...@@ -190,7 +234,7 @@ internal bool Contains(byte value)
public static implicit operator RedisChannel(string key) public static implicit operator RedisChannel(string key)
{ {
if (key == null) return default(RedisChannel); if (key == null) return default(RedisChannel);
return new RedisChannel(Encoding.UTF8.GetBytes(key)); return new RedisChannel(Encoding.UTF8.GetBytes(key), PatternMode.Auto);
} }
/// <summary> /// <summary>
/// Create a channel name from a Byte[] /// Create a channel name from a Byte[]
...@@ -198,7 +242,7 @@ internal bool Contains(byte value) ...@@ -198,7 +242,7 @@ internal bool Contains(byte value)
public static implicit operator RedisChannel(byte[] key) public static implicit operator RedisChannel(byte[] key)
{ {
if (key == null) return default(RedisChannel); if (key == null) return default(RedisChannel);
return new RedisChannel(key); return new RedisChannel(key, PatternMode.Auto);
} }
/// <summary> /// <summary>
/// Obtain the channel name as a Byte[] /// Obtain the channel name as a Byte[]
......
...@@ -449,14 +449,14 @@ public RedisChannel[] SubscriptionChannels(RedisChannel pattern = default(RedisC ...@@ -449,14 +449,14 @@ public RedisChannel[] SubscriptionChannels(RedisChannel pattern = default(RedisC
{ {
var msg = pattern.IsNullOrEmpty ? Message.Create(-1, flags, RedisCommand.PUBSUB, RedisLiterals.CHANNELS) var msg = pattern.IsNullOrEmpty ? Message.Create(-1, flags, RedisCommand.PUBSUB, RedisLiterals.CHANNELS)
: Message.Create(-1, flags, RedisCommand.PUBSUB, RedisLiterals.CHANNELS, pattern); : Message.Create(-1, flags, RedisCommand.PUBSUB, RedisLiterals.CHANNELS, pattern);
return ExecuteSync(msg, ResultProcessor.RedisChannelArray); return ExecuteSync(msg, ResultProcessor.RedisChannelArrayLiteral);
} }
public Task<RedisChannel[]> SubscriptionChannelsAsync(RedisChannel pattern = default(RedisChannel), CommandFlags flags = CommandFlags.None) public Task<RedisChannel[]> SubscriptionChannelsAsync(RedisChannel pattern = default(RedisChannel), CommandFlags flags = CommandFlags.None)
{ {
var msg = pattern.IsNullOrEmpty ? Message.Create(-1, flags, RedisCommand.PUBSUB, RedisLiterals.CHANNELS) var msg = pattern.IsNullOrEmpty ? Message.Create(-1, flags, RedisCommand.PUBSUB, RedisLiterals.CHANNELS)
: Message.Create(-1, flags, RedisCommand.PUBSUB, RedisLiterals.CHANNELS, pattern); : Message.Create(-1, flags, RedisCommand.PUBSUB, RedisLiterals.CHANNELS, pattern);
return ExecuteAsync(msg, ResultProcessor.RedisChannelArray); return ExecuteAsync(msg, ResultProcessor.RedisChannelArrayLiteral);
} }
public long SubscriptionPatternCount(CommandFlags flags = CommandFlags.None) public long SubscriptionPatternCount(CommandFlags flags = CommandFlags.None)
......
...@@ -186,7 +186,7 @@ public bool Remove(Action<RedisChannel, RedisValue> value) ...@@ -186,7 +186,7 @@ public bool Remove(Action<RedisChannel, RedisValue> value)
} }
public Task SubscribeToServer(ConnectionMultiplexer multiplexer, RedisChannel channel, CommandFlags flags, object asyncState, bool internalCall) public Task SubscribeToServer(ConnectionMultiplexer multiplexer, RedisChannel channel, CommandFlags flags, object asyncState, bool internalCall)
{ {
var cmd = channel.Contains((byte)'*') ? RedisCommand.PSUBSCRIBE : RedisCommand.SUBSCRIBE; var cmd = channel.IsPatternBased ? RedisCommand.PSUBSCRIBE : RedisCommand.SUBSCRIBE;
var selected = multiplexer.SelectServer(-1, cmd, CommandFlags.DemandMaster, default(RedisKey)); var selected = multiplexer.SelectServer(-1, cmd, CommandFlags.DemandMaster, default(RedisKey));
if (selected == null || Interlocked.CompareExchange(ref owner, selected, null) != null) return null; if (selected == null || Interlocked.CompareExchange(ref owner, selected, null) != null) return null;
...@@ -201,7 +201,7 @@ public Task UnsubscribeFromServer(RedisChannel channel, CommandFlags flags, obje ...@@ -201,7 +201,7 @@ public Task UnsubscribeFromServer(RedisChannel channel, CommandFlags flags, obje
var oldOwner = Interlocked.Exchange(ref owner, null); var oldOwner = Interlocked.Exchange(ref owner, null);
if (oldOwner == null) return null; if (oldOwner == null) return null;
var cmd = channel.Contains((byte)'*') ? RedisCommand.PUNSUBSCRIBE : RedisCommand.UNSUBSCRIBE; var cmd = channel.IsPatternBased ? RedisCommand.PUNSUBSCRIBE : RedisCommand.UNSUBSCRIBE;
var msg = Message.Create(-1, flags, cmd, channel); var msg = Message.Create(-1, flags, cmd, channel);
if (internalCall) msg.SetInternalCall(); if (internalCall) msg.SetInternalCall();
return oldOwner.QueueDirectAsync(msg, ResultProcessor.TrackSubscriptions, asyncState); return oldOwner.QueueDirectAsync(msg, ResultProcessor.TrackSubscriptions, asyncState);
...@@ -215,7 +215,7 @@ internal void Resubscribe(RedisChannel channel, ServerEndPoint server) ...@@ -215,7 +215,7 @@ internal void Resubscribe(RedisChannel channel, ServerEndPoint server)
{ {
if (server != null && Interlocked.CompareExchange(ref owner, server, server) == server) if (server != null && Interlocked.CompareExchange(ref owner, server, server) == server)
{ {
var cmd = channel.Contains((byte)'*') ? RedisCommand.PSUBSCRIBE : RedisCommand.SUBSCRIBE; var cmd = channel.IsPatternBased ? RedisCommand.PSUBSCRIBE : RedisCommand.SUBSCRIBE;
var msg = Message.Create(-1, CommandFlags.FireAndForget, cmd, channel); var msg = Message.Create(-1, CommandFlags.FireAndForget, cmd, channel);
msg.SetInternalCall(); msg.SetInternalCall();
server.QueueDirectFireAndForget(msg, ResultProcessor.TrackSubscriptions); server.QueueDirectFireAndForget(msg, ResultProcessor.TrackSubscriptions);
......
...@@ -50,7 +50,7 @@ abstract class ResultProcessor ...@@ -50,7 +50,7 @@ abstract class ResultProcessor
NullableInt64 = new NullableInt64Processor(); NullableInt64 = new NullableInt64Processor();
public static readonly ResultProcessor<RedisChannel[]> public static readonly ResultProcessor<RedisChannel[]>
RedisChannelArray = new RedisChannelArrayProcessor(); RedisChannelArrayLiteral = new RedisChannelArrayProcessor(RedisChannel.PatternMode.Literal);
public static readonly ResultProcessor<RedisKey> public static readonly ResultProcessor<RedisKey>
RedisKey = new RedisKeyProcessor(); RedisKey = new RedisKeyProcessor();
...@@ -971,6 +971,11 @@ protected override bool SetResultCore(PhysicalConnection connection, Message mes ...@@ -971,6 +971,11 @@ protected override bool SetResultCore(PhysicalConnection connection, Message mes
sealed class RedisChannelArrayProcessor : ResultProcessor<RedisChannel[]> sealed class RedisChannelArrayProcessor : ResultProcessor<RedisChannel[]>
{ {
private readonly RedisChannel.PatternMode mode;
public RedisChannelArrayProcessor(RedisChannel.PatternMode mode)
{
this.mode = mode;
}
protected override bool SetResultCore(PhysicalConnection connection, Message message, RawResult result) protected override bool SetResultCore(PhysicalConnection connection, Message message, RawResult result)
{ {
switch (result.Type) switch (result.Type)
...@@ -988,7 +993,7 @@ protected override bool SetResultCore(PhysicalConnection connection, Message mes ...@@ -988,7 +993,7 @@ protected override bool SetResultCore(PhysicalConnection connection, Message mes
byte[] channelPrefix = connection.ChannelPrefix; byte[] channelPrefix = connection.ChannelPrefix;
for (int i = 0; i < final.Length; i++) for (int i = 0; i < final.Length; i++)
{ {
final[i] = arr[i].AsRedisChannel(channelPrefix); final[i] = arr[i].AsRedisChannel(channelPrefix, mode);
} }
} }
SetResult(message, final); SetResult(message, final);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment