Added support for table name attribute on POCO class, should work with ANY...

Added support for table name attribute on POCO class, should work with ANY attribute named "TableAttribute" that has a "Name" property for it's value (supplied one, and also tested with the Table-attribute in EntityFramework).  Cleaned up ProxyGenerator for readability.
parent 3f411120
...@@ -42,9 +42,6 @@ ...@@ -42,9 +42,6 @@
<WarningLevel>4</WarningLevel> <WarningLevel>4</WarningLevel>
</PropertyGroup> </PropertyGroup>
<ItemGroup> <ItemGroup>
<Reference Include="EntityFramework">
<HintPath>Dependencies\EntityFramework.dll</HintPath>
</Reference>
<Reference Include="System" /> <Reference Include="System" />
<Reference Include="System.ComponentModel.DataAnnotations" /> <Reference Include="System.ComponentModel.DataAnnotations" />
<Reference Include="System.Core" /> <Reference Include="System.Core" />
......
...@@ -30,11 +30,8 @@ private static void Setup() ...@@ -30,11 +30,8 @@ private static void Setup()
using (var connection = new SqlCeConnection(connectionString)) using (var connection = new SqlCeConnection(connectionString))
{ {
connection.Open(); connection.Open();
var sql = connection.Execute(@" create table Users (Id int IDENTITY(1,1) not null, Name nvarchar(100) not null, Age int not null) ");
@" connection.Execute(@" create table Automobiles (Id int IDENTITY(1,1) not null, Name nvarchar(100) not null) ");
create table Users (Id int IDENTITY(1,1) not null, Name nvarchar(100) not null, Age int not null)
";
connection.Execute(sql);
} }
Console.WriteLine("Created database"); Console.WriteLine("Created database");
} }
......
using System; using System.ComponentModel.DataAnnotations;
using System.ComponentModel.DataAnnotations;
using System.Data; using System.Data;
using System.Data.SqlServerCe; using System.Data.SqlServerCe;
using System.Diagnostics;
using System.IO; using System.IO;
using System.Linq; using System.Linq;
using System.Reflection; using System.Reflection;
using Dapper.Contrib.Extensions; using Dapper.Contrib.Extensions;
namespace Dapper.Contrib.Tests namespace Dapper.Contrib.Tests
{ {
public interface IUser public interface IUser
...@@ -26,6 +24,13 @@ public class User : IUser ...@@ -26,6 +24,13 @@ public class User : IUser
public int Age { get; set; } public int Age { get; set; }
} }
[Table("Automobiles")]
public class Car
{
public int Id { get; set; }
public string Name { get; set; }
}
public class Tests public class Tests
{ {
private IDbConnection GetOpenConnection() private IDbConnection GetOpenConnection()
...@@ -38,7 +43,19 @@ private IDbConnection GetOpenConnection() ...@@ -38,7 +43,19 @@ private IDbConnection GetOpenConnection()
return connection; return connection;
} }
public void TableName()
{
using (var connection = GetOpenConnection())
{
// tests against "Automobiles" table (Table attribute)
connection.Insert(new Car {Name = "Volvo"});
connection.Get<Car>(1).Name.IsEqualTo("Volvo");
connection.Update(new Car() {Id = 1, Name = "Saab"}).IsEqualTo(true);
connection.Get<Car>(1).Name.IsEqualTo("Saab");
connection.Delete(new Car() {Id = 1}).IsEqualTo(true);
connection.Get<Car>(1).IsNull();
}
}
public void TestSimpleGet() public void TestSimpleGet()
{ {
......
...@@ -76,55 +76,39 @@ public static T GetInterfaceProxy<T>() ...@@ -76,55 +76,39 @@ public static T GetInterfaceProxy<T>()
private static MethodInfo CreateIsDirtyProperty(TypeBuilder typeBuilder) private static MethodInfo CreateIsDirtyProperty(TypeBuilder typeBuilder)
{ {
Type propType = typeof(bool); var propType = typeof(bool);
FieldBuilder field = typeBuilder.DefineField("_" + "IsDirty", propType, FieldAttributes.Private); var field = typeBuilder.DefineField("_" + "IsDirty", propType, FieldAttributes.Private);
// Generate a public property var property = typeBuilder.DefineProperty("IsDirty",
PropertyBuilder property =
typeBuilder.DefineProperty("IsDirty",
PropertyAttributes.None, PropertyAttributes.None,
propType, propType,
new Type[] { propType }); new Type[] { propType });
// The property set and property get methods require a special set of attributes: const MethodAttributes getSetAttr = MethodAttributes.Public | MethodAttributes.NewSlot | MethodAttributes.SpecialName |
MethodAttributes GetSetAttr =
MethodAttributes.Public | MethodAttributes.NewSlot | MethodAttributes.SpecialName |
MethodAttributes.Final | MethodAttributes.Virtual | MethodAttributes.HideBySig; MethodAttributes.Final | MethodAttributes.Virtual | MethodAttributes.HideBySig;
// Define the "get" accessor method for current private field. // Define the "get" and "set" accessor methods
MethodBuilder currGetPropMthdBldr = var currGetPropMthdBldr = typeBuilder.DefineMethod("get_" + "IsDirty",
typeBuilder.DefineMethod("get_" + "IsDirty", getSetAttr,
GetSetAttr,
propType, propType,
Type.EmptyTypes); Type.EmptyTypes);
var currGetIL = currGetPropMthdBldr.GetILGenerator();
// Intermediate Language stuff...
ILGenerator currGetIL = currGetPropMthdBldr.GetILGenerator();
currGetIL.Emit(OpCodes.Ldarg_0); currGetIL.Emit(OpCodes.Ldarg_0);
currGetIL.Emit(OpCodes.Ldfld, field); currGetIL.Emit(OpCodes.Ldfld, field);
currGetIL.Emit(OpCodes.Ret); currGetIL.Emit(OpCodes.Ret);
var currSetPropMthdBldr = typeBuilder.DefineMethod("set_" + "IsDirty",
// Define the "set" accessor method for current private field. getSetAttr,
MethodBuilder currSetPropMthdBldr =
typeBuilder.DefineMethod("set_" + "IsDirty",
GetSetAttr,
null, null,
new Type[] { propType }); new Type[] { propType });
var currSetIL = currSetPropMthdBldr.GetILGenerator();
// Again some Intermediate Language stuff...
ILGenerator currSetIL = currSetPropMthdBldr.GetILGenerator();
currSetIL.Emit(OpCodes.Ldarg_0); currSetIL.Emit(OpCodes.Ldarg_0);
currSetIL.Emit(OpCodes.Ldarg_1); currSetIL.Emit(OpCodes.Ldarg_1);
currSetIL.Emit(OpCodes.Stfld, field); currSetIL.Emit(OpCodes.Stfld, field);
currSetIL.Emit(OpCodes.Ret); currSetIL.Emit(OpCodes.Ret);
// Last, we must map the two methods created above to our PropertyBuilder to
// their corresponding behaviors, "get" and "set" respectively.
property.SetGetMethod(currGetPropMthdBldr); property.SetGetMethod(currGetPropMthdBldr);
property.SetSetMethod(currSetPropMthdBldr); property.SetSetMethod(currSetPropMthdBldr);
var getMethod = typeof(IProxy).GetMethod("get_" + "IsDirty");
MethodInfo getMethod = typeof(IProxy).GetMethod("get_" + "IsDirty"); var setMethod = typeof(IProxy).GetMethod("set_" + "IsDirty");
MethodInfo setMethod = typeof(IProxy).GetMethod("set_" + "IsDirty");
typeBuilder.DefineMethodOverride(currGetPropMthdBldr, getMethod); typeBuilder.DefineMethodOverride(currGetPropMthdBldr, getMethod);
typeBuilder.DefineMethodOverride(currSetPropMthdBldr, setMethod); typeBuilder.DefineMethodOverride(currSetPropMthdBldr, setMethod);
...@@ -133,42 +117,34 @@ private static MethodInfo CreateIsDirtyProperty(TypeBuilder typeBuilder) ...@@ -133,42 +117,34 @@ private static MethodInfo CreateIsDirtyProperty(TypeBuilder typeBuilder)
private static void CreateProperty<T>(TypeBuilder typeBuilder, string propertyName, Type propType, MethodInfo setIsDirtyMethod, bool isIdentity) private static void CreateProperty<T>(TypeBuilder typeBuilder, string propertyName, Type propType, MethodInfo setIsDirtyMethod, bool isIdentity)
{ {
FieldBuilder field = typeBuilder.DefineField("_" + propertyName, propType, FieldAttributes.Private); //Define the field and the property
// Generate a public property var field = typeBuilder.DefineField("_" + propertyName, propType, FieldAttributes.Private);
PropertyBuilder property = var property = typeBuilder.DefineProperty(propertyName,
typeBuilder.DefineProperty(propertyName,
PropertyAttributes.None, PropertyAttributes.None,
propType, propType,
new Type[] { propType }); new Type[] { propType });
// The property set and property get methods require a special set of attributes: const MethodAttributes getSetAttr = MethodAttributes.Public | MethodAttributes.Virtual |
MethodAttributes GetSetAttr =
MethodAttributes.Public | MethodAttributes.Virtual |
MethodAttributes.HideBySig; MethodAttributes.HideBySig;
// Define the "get" accessor method for current private field. // Define the "get" and "set" accessor methods
MethodBuilder currGetPropMthdBldr = var currGetPropMthdBldr = typeBuilder.DefineMethod("get_" + propertyName,
typeBuilder.DefineMethod("get_" + propertyName, getSetAttr,
GetSetAttr,
propType, propType,
Type.EmptyTypes); Type.EmptyTypes);
// Intermediate Language stuff... var currGetIL = currGetPropMthdBldr.GetILGenerator();
ILGenerator currGetIL = currGetPropMthdBldr.GetILGenerator();
currGetIL.Emit(OpCodes.Ldarg_0); currGetIL.Emit(OpCodes.Ldarg_0);
currGetIL.Emit(OpCodes.Ldfld, field); currGetIL.Emit(OpCodes.Ldfld, field);
currGetIL.Emit(OpCodes.Ret); currGetIL.Emit(OpCodes.Ret);
// Define the "set" accessor method for current private field. var currSetPropMthdBldr = typeBuilder.DefineMethod("set_" + propertyName,
MethodBuilder currSetPropMthdBldr = getSetAttr,
typeBuilder.DefineMethod("set_" + propertyName,
GetSetAttr,
null, null,
new Type[] { propType }); new Type[] { propType });
// Again some Intermediate Language stuff... //store value in private field and set the isdirty flag
ILGenerator currSetIL = currSetPropMthdBldr.GetILGenerator(); var currSetIL = currSetPropMthdBldr.GetILGenerator();
currSetIL.Emit(OpCodes.Ldarg_0); currSetIL.Emit(OpCodes.Ldarg_0);
currSetIL.Emit(OpCodes.Ldarg_1); currSetIL.Emit(OpCodes.Ldarg_1);
currSetIL.Emit(OpCodes.Stfld, field); currSetIL.Emit(OpCodes.Stfld, field);
...@@ -177,23 +153,19 @@ private static void CreateProperty<T>(TypeBuilder typeBuilder, string propertyNa ...@@ -177,23 +153,19 @@ private static void CreateProperty<T>(TypeBuilder typeBuilder, string propertyNa
currSetIL.Emit(OpCodes.Call, setIsDirtyMethod); currSetIL.Emit(OpCodes.Call, setIsDirtyMethod);
currSetIL.Emit(OpCodes.Ret); currSetIL.Emit(OpCodes.Ret);
//TODO: Should copy all attributes defined by the interface?
if (isIdentity) if (isIdentity)
{ {
Type keyAttribute = typeof(KeyAttribute); var keyAttribute = typeof(KeyAttribute);
// Create a Constructorinfo object for attribute 'MyAttribute1'. var myConstructorInfo = keyAttribute.GetConstructor(new Type[] { });
ConstructorInfo myConstructorInfo = keyAttribute.GetConstructor(new Type[] { }); var attributeBuilder = new CustomAttributeBuilder(myConstructorInfo, new object[] { });
// Create the CustomAttribute instance of attribute of type 'MyAttribute1'.
CustomAttributeBuilder attributeBuilder = new CustomAttributeBuilder(myConstructorInfo, new object[] { });
property.SetCustomAttribute(attributeBuilder); property.SetCustomAttribute(attributeBuilder);
} }
// Last, we must map the two methods created above to our PropertyBuilder to
// their corresponding behaviors, "get" and "set" respectively.
property.SetGetMethod(currGetPropMthdBldr); property.SetGetMethod(currGetPropMthdBldr);
property.SetSetMethod(currSetPropMthdBldr); property.SetSetMethod(currSetPropMthdBldr);
var getMethod = typeof(T).GetMethod("get_" + propertyName);
MethodInfo getMethod = typeof(T).GetMethod("get_" + propertyName); var setMethod = typeof(T).GetMethod("set_" + propertyName);
MethodInfo setMethod = typeof(T).GetMethod("set_" + propertyName);
typeBuilder.DefineMethodOverride(currGetPropMthdBldr, getMethod); typeBuilder.DefineMethodOverride(currGetPropMthdBldr, getMethod);
typeBuilder.DefineMethodOverride(currSetPropMthdBldr, setMethod); typeBuilder.DefineMethodOverride(currSetPropMthdBldr, setMethod);
} }
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
using System.Collections.Generic; using System.Collections.Generic;
using System.ComponentModel.DataAnnotations; using System.ComponentModel.DataAnnotations;
using System.Data; using System.Data;
using System.Diagnostics;
using System.Linq; using System.Linq;
using System.Reflection; using System.Reflection;
using System.Text; using System.Text;
...@@ -15,6 +16,7 @@ public static class SqlMapperExtensions ...@@ -15,6 +16,7 @@ public static class SqlMapperExtensions
private static readonly ConcurrentDictionary<RuntimeTypeHandle, IEnumerable<PropertyInfo>> KeyProperties = new ConcurrentDictionary<RuntimeTypeHandle, IEnumerable<PropertyInfo>>(); private static readonly ConcurrentDictionary<RuntimeTypeHandle, IEnumerable<PropertyInfo>> KeyProperties = new ConcurrentDictionary<RuntimeTypeHandle, IEnumerable<PropertyInfo>>();
private static readonly ConcurrentDictionary<RuntimeTypeHandle, IEnumerable<PropertyInfo>> TypeProperties = new ConcurrentDictionary<RuntimeTypeHandle, IEnumerable<PropertyInfo>>(); private static readonly ConcurrentDictionary<RuntimeTypeHandle, IEnumerable<PropertyInfo>> TypeProperties = new ConcurrentDictionary<RuntimeTypeHandle, IEnumerable<PropertyInfo>>();
private static readonly ConcurrentDictionary<RuntimeTypeHandle, string> GetQueries = new ConcurrentDictionary<RuntimeTypeHandle, string>(); private static readonly ConcurrentDictionary<RuntimeTypeHandle, string> GetQueries = new ConcurrentDictionary<RuntimeTypeHandle, string>();
private static readonly ConcurrentDictionary<RuntimeTypeHandle, string> TypeTableName = new ConcurrentDictionary<RuntimeTypeHandle, string>();
private static IEnumerable<PropertyInfo> KeyPropertiesCache(Type type) private static IEnumerable<PropertyInfo> KeyPropertiesCache(Type type)
{ {
...@@ -73,13 +75,12 @@ private static IEnumerable<PropertyInfo> TypePropertiesCache(Type type) ...@@ -73,13 +75,12 @@ private static IEnumerable<PropertyInfo> TypePropertiesCache(Type type)
throw new DataException("Get<T> only supports en entity with a [Key] property"); throw new DataException("Get<T> only supports en entity with a [Key] property");
var onlyKey = keys.First(); var onlyKey = keys.First();
var name = type.Name;
if (type.IsInterface && name.StartsWith("I")) var name = GetTableName(type);
name = name.Substring(1);
// TODO: pluralizer // TODO: pluralizer
// TODO: query information schema and only select fields that are both in information schema and underlying class / interface // TODO: query information schema and only select fields that are both in information schema and underlying class / interface
sql = "select * from " + name + "s where " + onlyKey.Name + " = @id"; sql = "select * from " + name + " where " + onlyKey.Name + " = @id";
GetQueries[type.TypeHandle] = sql; GetQueries[type.TypeHandle] = sql;
} }
...@@ -112,7 +113,24 @@ private static IEnumerable<PropertyInfo> TypePropertiesCache(Type type) ...@@ -112,7 +113,24 @@ private static IEnumerable<PropertyInfo> TypePropertiesCache(Type type)
return obj; return obj;
} }
private static string GetTableName(Type type)
{
string name;
if (!TypeTableName.TryGetValue(type.TypeHandle, out name))
{
name = type.Name + "s";
if (type.IsInterface && name.StartsWith("I"))
name = name.Substring(1);
//NOTE: This as dynamic trick should be able to handle both our own Table-attribute as well as the one in EntityFramework
var tableattr = type.GetCustomAttributes(false).Where(attr => attr.GetType().Name == "TableAttribute").SingleOrDefault() as
dynamic;
if (tableattr != null)
name = tableattr.Name;
TypeTableName[type.TypeHandle] = name;
}
return name;
}
/// <summary> /// <summary>
/// Inserts an entity into table "Ts" and returns identity id. /// Inserts an entity into table "Ts" and returns identity id.
...@@ -120,16 +138,19 @@ private static IEnumerable<PropertyInfo> TypePropertiesCache(Type type) ...@@ -120,16 +138,19 @@ private static IEnumerable<PropertyInfo> TypePropertiesCache(Type type)
/// <param name="connection">Open SqlConnection</param> /// <param name="connection">Open SqlConnection</param>
/// <param name="entityToInsert">Entity to insert</param> /// <param name="entityToInsert">Entity to insert</param>
/// <returns>Identity of inserted entity</returns> /// <returns>Identity of inserted entity</returns>
public static long Insert<T>(this IDbConnection connection, T entityToInsert) public static long Insert<T>(this IDbConnection connection, T entityToInsert) where T : class
{ {
using (var tx = connection.BeginTransaction()) using (var tx = connection.BeginTransaction())
{ {
var name = entityToInsert.GetType().Name; var type = typeof(T);
var name = GetTableName(type);
var sb = new StringBuilder(null); var sb = new StringBuilder(null);
sb.AppendFormat("insert into {0}s (", name); sb.AppendFormat("insert into {0} (", name);
var allProperties = TypePropertiesCache(typeof(T)); var allProperties = TypePropertiesCache(type);
var keyProperties = KeyPropertiesCache(typeof(T)); var keyProperties = KeyPropertiesCache(type);
for (var i = 0; i < allProperties.Count(); i++) for (var i = 0; i < allProperties.Count(); i++)
{ {
...@@ -166,7 +187,7 @@ public static long Insert<T>(this IDbConnection connection, T entityToInsert) ...@@ -166,7 +187,7 @@ public static long Insert<T>(this IDbConnection connection, T entityToInsert)
/// <param name="connection">Open SqlConnection</param> /// <param name="connection">Open SqlConnection</param>
/// <param name="entityToUpdate">Entity to be updated</param> /// <param name="entityToUpdate">Entity to be updated</param>
/// <returns>true if updated, false if not found or not modified (tracked entities)</returns> /// <returns>true if updated, false if not found or not modified (tracked entities)</returns>
public static bool Update<T>(this IDbConnection connection, T entityToUpdate) public static bool Update<T>(this IDbConnection connection, T entityToUpdate) where T : class
{ {
var proxy = entityToUpdate as IProxy; var proxy = entityToUpdate as IProxy;
if (proxy != null) if (proxy != null)
...@@ -180,12 +201,10 @@ public static bool Update<T>(this IDbConnection connection, T entityToUpdate) ...@@ -180,12 +201,10 @@ public static bool Update<T>(this IDbConnection connection, T entityToUpdate)
if (keyProperties.Count() == 0) if (keyProperties.Count() == 0)
throw new ArgumentException("Entity must have at least one [Key] property"); throw new ArgumentException("Entity must have at least one [Key] property");
var name = type.Name; var name = GetTableName(type);
if (type.IsInterface && name.StartsWith("I"))
name = name.Substring(1);
var sb = new StringBuilder(); var sb = new StringBuilder();
sb.AppendFormat("update {0}s set ", name); sb.AppendFormat("update {0} set ", name);
var allProperties = TypePropertiesCache(type); var allProperties = TypePropertiesCache(type);
var nonIdProps = allProperties.Where(a => !keyProperties.Contains(a)); var nonIdProps = allProperties.Where(a => !keyProperties.Contains(a));
...@@ -216,7 +235,7 @@ public static bool Update<T>(this IDbConnection connection, T entityToUpdate) ...@@ -216,7 +235,7 @@ public static bool Update<T>(this IDbConnection connection, T entityToUpdate)
/// <param name="connection">Open SqlConnection</param> /// <param name="connection">Open SqlConnection</param>
/// <param name="entityToDelete">Entity to delete</param> /// <param name="entityToDelete">Entity to delete</param>
/// <returns>true if deleted, false if not found</returns> /// <returns>true if deleted, false if not found</returns>
public static bool Delete<T>(this IDbConnection connection, T entityToDelete) public static bool Delete<T>(this IDbConnection connection, T entityToDelete) where T : class
{ {
var type = typeof(T); var type = typeof(T);
...@@ -224,12 +243,10 @@ public static bool Delete<T>(this IDbConnection connection, T entityToDelete) ...@@ -224,12 +243,10 @@ public static bool Delete<T>(this IDbConnection connection, T entityToDelete)
if (keyProperties.Count() == 0) if (keyProperties.Count() == 0)
throw new ArgumentException("Entity must have at least one [Key] property"); throw new ArgumentException("Entity must have at least one [Key] property");
var name = type.Name; var name = GetTableName(type);
if (type.IsInterface && name.StartsWith("I"))
name = name.Substring(1);
var sb = new StringBuilder(); var sb = new StringBuilder();
sb.AppendFormat("delete from {0}s where ", name); sb.AppendFormat("delete from {0} where ", name);
for (var i = 0; i < keyProperties.Count(); i++) for (var i = 0; i < keyProperties.Count(); i++)
{ {
...@@ -242,4 +259,14 @@ public static bool Delete<T>(this IDbConnection connection, T entityToDelete) ...@@ -242,4 +259,14 @@ public static bool Delete<T>(this IDbConnection connection, T entityToDelete)
return deleted > 0; return deleted > 0;
} }
} }
[AttributeUsage(AttributeTargets.Class)]
public class TableAttribute : Attribute
{
public TableAttribute(string tableName)
{
Name = tableName;
}
public string Name { get; private set; }
}
} }
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
using System.Linq; using System.Linq;
using System.Runtime.CompilerServices; using System.Runtime.CompilerServices;
namespace Dapper.Contrib namespace Dapper.Contrib.Extensions
{ {
public static class TypeExtension public static class TypeExtension
{ {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment