Commit 3d3ce558 authored by Nick Craver's avatar Nick Craver

Merge pull request #428 from joncloud/contribArrayIssue

Added checks to make sure that arrays get the appropriate type when d…
parents fa03d54b a967b0c2
...@@ -135,7 +135,12 @@ public static partial class SqlMapperExtensions ...@@ -135,7 +135,12 @@ public static partial class SqlMapperExtensions
sqlAdapter = GetFormatter(connection); sqlAdapter = GetFormatter(connection);
var isList = false; var isList = false;
if (type.IsArray || type.IsGenericType()) if (type.IsArray)
{
isList = true;
type = type.GetElementType();
}
else if (type.IsGenericType())
{ {
isList = true; isList = true;
type = type.GetGenericArguments()[0]; type = type.GetGenericArguments()[0];
...@@ -195,8 +200,14 @@ public static partial class SqlMapperExtensions ...@@ -195,8 +200,14 @@ public static partial class SqlMapperExtensions
var type = typeof(T); var type = typeof(T);
if (type.IsArray || type.IsGenericType()) if (type.IsArray)
{
type = type.GetElementType();
}
else if (type.IsGenericType())
{
type = type.GetGenericArguments()[0]; type = type.GetGenericArguments()[0];
}
var keyProperties = KeyPropertiesCache(type); var keyProperties = KeyPropertiesCache(type);
var explicitKeyProperties = ExplicitKeyPropertiesCache(type); var explicitKeyProperties = ExplicitKeyPropertiesCache(type);
...@@ -250,8 +261,14 @@ public static partial class SqlMapperExtensions ...@@ -250,8 +261,14 @@ public static partial class SqlMapperExtensions
var type = typeof(T); var type = typeof(T);
if (type.IsArray || type.IsGenericType()) if (type.IsArray)
{
type = type.GetElementType();
}
else if (type.IsGenericType())
{
type = type.GetGenericArguments()[0]; type = type.GetGenericArguments()[0];
}
var keyProperties = KeyPropertiesCache(type); var keyProperties = KeyPropertiesCache(type);
var explicitKeyProperties = ExplicitKeyPropertiesCache(type); var explicitKeyProperties = ExplicitKeyPropertiesCache(type);
......
...@@ -295,7 +295,12 @@ private static string GetTableName(Type type) ...@@ -295,7 +295,12 @@ private static string GetTableName(Type type)
var type = typeof(T); var type = typeof(T);
if (type.IsArray || type.IsGenericType()) if (type.IsArray)
{
isList = true;
type = type.GetElementType();
}
else if (type.IsGenericType())
{ {
isList = true; isList = true;
type = type.GetGenericArguments()[0]; type = type.GetGenericArguments()[0];
...@@ -365,8 +370,14 @@ private static string GetTableName(Type type) ...@@ -365,8 +370,14 @@ private static string GetTableName(Type type)
var type = typeof(T); var type = typeof(T);
if (type.IsArray || type.IsGenericType()) if (type.IsArray)
{
type = type.GetElementType();
}
else if (type.IsGenericType())
{
type = type.GetGenericArguments()[0]; type = type.GetGenericArguments()[0];
}
var keyProperties = KeyPropertiesCache(type).ToList(); //added ToList() due to issue #418, must work on a list copy var keyProperties = KeyPropertiesCache(type).ToList(); //added ToList() due to issue #418, must work on a list copy
var explicitKeyProperties = ExplicitKeyPropertiesCache(type); var explicitKeyProperties = ExplicitKeyPropertiesCache(type);
...@@ -420,8 +431,14 @@ private static string GetTableName(Type type) ...@@ -420,8 +431,14 @@ private static string GetTableName(Type type)
var type = typeof(T); var type = typeof(T);
if (type.IsArray || type.IsGenericType()) if (type.IsArray)
{
type = type.GetElementType();
}
else if (type.IsGenericType())
{
type = type.GetGenericArguments()[0]; type = type.GetGenericArguments()[0];
}
var keyProperties = KeyPropertiesCache(type).ToList(); //added ToList() due to issue #418, must work on a list copy var keyProperties = KeyPropertiesCache(type).ToList(); //added ToList() due to issue #418, must work on a list copy
var explicitKeyProperties = ExplicitKeyPropertiesCache(type); var explicitKeyProperties = ExplicitKeyPropertiesCache(type);
......
...@@ -194,8 +194,20 @@ public async Task BuilderTemplateWithoutCompositionAsync() ...@@ -194,8 +194,20 @@ public async Task BuilderTemplateWithoutCompositionAsync()
} }
} }
[Fact]
public async Task InsertArrayAsync()
{
await InsertHelperAsync(src => src.ToArray());
}
[Fact] [Fact]
public async Task InsertListAsync() public async Task InsertListAsync()
{
await InsertHelperAsync(src => src.ToList());
}
private async Task InsertHelperAsync<T>(Func<IEnumerable<User>, T> helper)
where T : class
{ {
const int numberOfEntities = 10; const int numberOfEntities = 10;
...@@ -207,15 +219,27 @@ public async Task InsertListAsync() ...@@ -207,15 +219,27 @@ public async Task InsertListAsync()
{ {
await connection.DeleteAllAsync<User>(); await connection.DeleteAllAsync<User>();
var total = await connection.InsertAsync(users); var total = await connection.InsertAsync(helper(users));
total.IsEqualTo(numberOfEntities); total.IsEqualTo(numberOfEntities);
users = connection.Query<User>("select * from users").ToList(); users = connection.Query<User>("select * from users").ToList();
users.Count.IsEqualTo(numberOfEntities); users.Count.IsEqualTo(numberOfEntities);
} }
} }
[Fact]
public async Task UpdateArrayAsync()
{
await UpdateHelperAsync(src => src.ToArray());
}
[Fact] [Fact]
public async Task UpdateListAsync() public async Task UpdateListAsync()
{
await UpdateHelperAsync(src => src.ToList());
}
private async Task UpdateHelperAsync<T>(Func<IEnumerable<User>, T> helper)
where T : class
{ {
const int numberOfEntities = 10; const int numberOfEntities = 10;
...@@ -227,7 +251,7 @@ public async Task UpdateListAsync() ...@@ -227,7 +251,7 @@ public async Task UpdateListAsync()
{ {
await connection.DeleteAllAsync<User>(); await connection.DeleteAllAsync<User>();
var total = await connection.InsertAsync(users); var total = await connection.InsertAsync(helper(users));
total.IsEqualTo(numberOfEntities); total.IsEqualTo(numberOfEntities);
users = connection.Query<User>("select * from users").ToList(); users = connection.Query<User>("select * from users").ToList();
users.Count.IsEqualTo(numberOfEntities); users.Count.IsEqualTo(numberOfEntities);
...@@ -235,14 +259,26 @@ public async Task UpdateListAsync() ...@@ -235,14 +259,26 @@ public async Task UpdateListAsync()
{ {
user.Name = user.Name + " updated"; user.Name = user.Name + " updated";
} }
await connection.UpdateAsync(users); await connection.UpdateAsync(helper(users));
var name = connection.Query<User>("select * from users").First().Name; var name = connection.Query<User>("select * from users").First().Name;
name.Contains("updated").IsTrue(); name.Contains("updated").IsTrue();
} }
} }
[Fact]
public async Task DeleteArrayAsync()
{
await DeleteHelperAsync(src => src.ToArray());
}
[Fact] [Fact]
public async Task DeleteListAsync() public async Task DeleteListAsync()
{
await DeleteHelperAsync(src => src.ToList());
}
private async Task DeleteHelperAsync<T>(Func<IEnumerable<User>, T> helper)
where T : class
{ {
const int numberOfEntities = 10; const int numberOfEntities = 10;
...@@ -254,13 +290,13 @@ public async Task DeleteListAsync() ...@@ -254,13 +290,13 @@ public async Task DeleteListAsync()
{ {
await connection.DeleteAllAsync<User>(); await connection.DeleteAllAsync<User>();
var total = await connection.InsertAsync(users); var total = await connection.InsertAsync(helper(users));
total.IsEqualTo(numberOfEntities); total.IsEqualTo(numberOfEntities);
users = connection.Query<User>("select * from users").ToList(); users = connection.Query<User>("select * from users").ToList();
users.Count.IsEqualTo(numberOfEntities); users.Count.IsEqualTo(numberOfEntities);
var usersToDelete = users.Take(10).ToList(); var usersToDelete = users.Take(10).ToList();
await connection.DeleteAsync(usersToDelete); await connection.DeleteAsync(helper(usersToDelete));
users = connection.Query<User>("select * from users").ToList(); users = connection.Query<User>("select * from users").ToList();
users.Count.IsEqualTo(numberOfEntities - 10); users.Count.IsEqualTo(numberOfEntities - 10);
} }
......
...@@ -280,8 +280,20 @@ public void TestClosedConnection() ...@@ -280,8 +280,20 @@ public void TestClosedConnection()
} }
} }
[Fact]
public void InsertArray()
{
InsertHelper(src => src.ToArray());
}
[Fact] [Fact]
public void InsertList() public void InsertList()
{
InsertHelper(src => src.ToList());
}
private void InsertHelper<T>(Func<IEnumerable<User>, T> helper)
where T : class
{ {
const int numberOfEntities = 10; const int numberOfEntities = 10;
...@@ -293,15 +305,27 @@ public void InsertList() ...@@ -293,15 +305,27 @@ public void InsertList()
{ {
connection.DeleteAll<User>(); connection.DeleteAll<User>();
var total = connection.Insert(users); var total = connection.Insert(helper(users));
total.IsEqualTo(numberOfEntities); total.IsEqualTo(numberOfEntities);
users = connection.Query<User>("select * from users").ToList(); users = connection.Query<User>("select * from users").ToList();
users.Count.IsEqualTo(numberOfEntities); users.Count.IsEqualTo(numberOfEntities);
} }
} }
[Fact]
public void UpdateArray()
{
UpdateHelper(src => src.ToArray());
}
[Fact] [Fact]
public void UpdateList() public void UpdateList()
{
UpdateHelper(src => src.ToList());
}
private void UpdateHelper<T>(Func<IEnumerable<User>, T> helper)
where T : class
{ {
const int numberOfEntities = 10; const int numberOfEntities = 10;
...@@ -313,7 +337,7 @@ public void UpdateList() ...@@ -313,7 +337,7 @@ public void UpdateList()
{ {
connection.DeleteAll<User>(); connection.DeleteAll<User>();
var total = connection.Insert(users); var total = connection.Insert(helper(users));
total.IsEqualTo(numberOfEntities); total.IsEqualTo(numberOfEntities);
users = connection.Query<User>("select * from users").ToList(); users = connection.Query<User>("select * from users").ToList();
users.Count.IsEqualTo(numberOfEntities); users.Count.IsEqualTo(numberOfEntities);
...@@ -321,14 +345,26 @@ public void UpdateList() ...@@ -321,14 +345,26 @@ public void UpdateList()
{ {
user.Name = user.Name + " updated"; user.Name = user.Name + " updated";
} }
connection.Update(users); connection.Update(helper(users));
var name = connection.Query<User>("select * from users").First().Name; var name = connection.Query<User>("select * from users").First().Name;
name.Contains("updated").IsTrue(); name.Contains("updated").IsTrue();
} }
} }
[Fact]
public void DeleteArray()
{
DeleteHelper(src => src.ToArray());
}
[Fact] [Fact]
public void DeleteList() public void DeleteList()
{
DeleteHelper(src => src.ToList());
}
private void DeleteHelper<T>(Func<IEnumerable<User>, T> helper)
where T : class
{ {
const int numberOfEntities = 10; const int numberOfEntities = 10;
...@@ -340,17 +376,16 @@ public void DeleteList() ...@@ -340,17 +376,16 @@ public void DeleteList()
{ {
connection.DeleteAll<User>(); connection.DeleteAll<User>();
var total = connection.Insert(users); var total = connection.Insert(helper(users));
total.IsEqualTo(numberOfEntities); total.IsEqualTo(numberOfEntities);
users = connection.Query<User>("select * from users").ToList(); users = connection.Query<User>("select * from users").ToList();
users.Count.IsEqualTo(numberOfEntities); users.Count.IsEqualTo(numberOfEntities);
var usersToDelete = users.Take(10).ToList(); var usersToDelete = users.Take(10).ToList();
connection.Delete(usersToDelete); connection.Delete(helper(usersToDelete));
users = connection.Query<User>("select * from users").ToList(); users = connection.Query<User>("select * from users").ToList();
users.Count.IsEqualTo(numberOfEntities - 10); users.Count.IsEqualTo(numberOfEntities - 10);
} }
} }
[Fact] [Fact]
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment