Unverified Commit 9fbba78d authored by Marc Gravell's avatar Marc Gravell Committed by GitHub

Make Dapper happy with Microsoft.Data.SqlClient (#1262)

* prep work for adding Microsoft.Data.SqlClient support/tests; SqlClient dependency is not yet removed
parent 33d974a5
......@@ -5,7 +5,7 @@
<Title>Dapper (Strong Named)</Title>
<Description>A high performance Micro-ORM supporting SQL Server, MySQL, Sqlite, SqlCE, Firebird etc..</Description>
<Authors>Sam Saffron;Marc Gravell;Nick Craver</Authors>
<TargetFrameworks>net451;netstandard1.3;netstandard2.0</TargetFrameworks>
<TargetFrameworks>net451;netstandard1.3;netstandard2.0;netcoreapp2.1</TargetFrameworks>
<SignAssembly>true</SignAssembly>
<PublicSign Condition=" '$(OS)' != 'Windows_NT' ">true</PublicSign>
</PropertyGroup>
......@@ -20,9 +20,15 @@
<Reference Include="Microsoft.CSharp" />
</ItemGroup>
<ItemGroup Condition=" '$(TargetFramework)' == 'netstandard1.3' OR '$(TargetFramework)' == 'netstandard2.0'">
<PackageReference Include="System.Data.SqlClient" Version="4.6.0" />
<PackageReference Include="System.Reflection.Emit.Lightweight" Version="4.3.0" />
<PackageReference Include="System.Reflection.TypeExtensions" Version="4.4.0" />
<!-- it would be nice to use System.Data.Common here, but we need SqlClient for SqlDbType in 1.3, and legacy SqlDataRecord API-->
<!--<PackageReference Include="System.Data.Common" Version="4.3.0" />-->
<PackageReference Include="System.Data.SqlClient" Version="4.6.0" />
</ItemGroup>
<ItemGroup Condition="'$(TargetFramework)' == 'netcoreapp2.1'">
<PackageReference Include="System.Data.SqlClient" Version="4.6.0" />
</ItemGroup>
<ItemGroup Condition=" '$(TargetFramework)' == 'netstandard1.3' ">
<PackageReference Include="System.Collections.Concurrent" Version="4.3.0" />
......
......@@ -21,5 +21,6 @@
<PackageReference Include="MySqlConnector" Version="0.44.1" />
<PackageReference Include="xunit" Version="$(xUnitVersion)" />
<PackageReference Include="xunit.runner.visualstudio" Version="$(xUnitVersion)" />
<PackageReference Include="System.Data.SqlClient" Version="4.6.0" />
</ItemGroup>
</Project>
......@@ -4,15 +4,31 @@
using System;
using System.Threading.Tasks;
using System.Threading;
using System.Data.SqlClient;
using Xunit;
using System.Data.Common;
namespace Dapper.Tests
{
public class Tests : TestBase
[Collection(NonParallelDefinition.Name)]
public sealed class SystemSqlClientAsyncTests : AsyncTests<SystemSqlClientProvider> { }
#if MSSQLCLIENT
[Collection(NonParallelDefinition.Name)]
public sealed class MicrosoftSqlClientAsyncTests : AsyncTests<MicrosoftSqlClientProvider> { }
#endif
[Collection(NonParallelDefinition.Name)]
public sealed class SystemSqlClientAsyncQueryCacheTests : AsyncQueryCacheTests<SystemSqlClientProvider> { }
#if MSSQLCLIENT
[Collection(NonParallelDefinition.Name)]
public sealed class MicrosoftSqlClientAsyncQueryCacheTests : AsyncQueryCacheTests<MicrosoftSqlClientProvider> { }
#endif
public abstract class AsyncTests<TProvider> : TestBase<TProvider> where TProvider : SqlServerDatabaseProvider
{
private SqlConnection _marsConnection;
private SqlConnection MarsConnection => _marsConnection ?? (_marsConnection = GetOpenConnection(true));
private DbConnection _marsConnection;
private DbConnection MarsConnection => _marsConnection ?? (_marsConnection = Provider.GetOpenConnection(true));
[Fact]
public async Task TestBasicStringUsageAsync()
......@@ -100,7 +116,7 @@ public void TestLongOperationWithCancellation()
}
catch (AggregateException agg)
{
Assert.True(agg.InnerException is SqlException);
Assert.True(agg.InnerException.GetType().Name == "SqlException");
}
}
......@@ -382,38 +398,6 @@ public void RunSequentialVersusParallelSync()
Console.WriteLine("Pipeline: {0}ms", watch.ElapsedMilliseconds);
}
[Collection(NonParallelDefinition.Name)]
public class AsyncQueryCacheTests : TestBase
{
private SqlConnection _marsConnection;
private SqlConnection MarsConnection => _marsConnection ?? (_marsConnection = GetOpenConnection(true));
[Fact]
public void AssertNoCacheWorksForQueryMultiple()
{
const int a = 123, b = 456;
var cmdDef = new CommandDefinition("select @a; select @b;", new
{
a,
b
}, commandType: CommandType.Text, flags: CommandFlags.NoCache);
int c, d;
SqlMapper.PurgeQueryCache();
int before = SqlMapper.GetCachedSQLCount();
using (var multi = MarsConnection.QueryMultiple(cmdDef))
{
c = multi.Read<int>().Single();
d = multi.Read<int>().Single();
}
int after = SqlMapper.GetCachedSQLCount();
Assert.Equal(0, before);
Assert.Equal(0, after);
Assert.Equal(123, c);
Assert.Equal(456, d);
}
}
private class BasicType
{
public string Value { get; set; }
......@@ -827,7 +811,46 @@ public async Task Issue563_QueryAsyncShouldThrowException()
var data = (await connection.QueryAsync<int>("select 1 union all select 2; RAISERROR('after select', 16, 1);").ConfigureAwait(false)).ToList();
Assert.True(false, "Expected Exception");
}
catch (SqlException ex) when (ex.Message == "after select") { /* swallow only this */ }
catch (Exception ex) when (ex.GetType().Name == "SqlException" && ex.Message == "after select") { /* swallow only this */ }
}
}
[Collection(NonParallelDefinition.Name)]
public abstract class AsyncQueryCacheTests<TProvider> : TestBase<TProvider> where TProvider : SqlServerDatabaseProvider
{
private DbConnection _marsConnection;
private DbConnection MarsConnection => _marsConnection ?? (_marsConnection = Provider.GetOpenConnection(true));
public override void Dispose()
{
_marsConnection?.Dispose();
_marsConnection = null;
base.Dispose();
}
[Fact]
public void AssertNoCacheWorksForQueryMultiple()
{
const int a = 123, b = 456;
var cmdDef = new CommandDefinition("select @a; select @b;", new
{
a,
b
}, commandType: CommandType.Text, flags: CommandFlags.NoCache);
int c, d;
SqlMapper.PurgeQueryCache();
int before = SqlMapper.GetCachedSQLCount();
using (var multi = MarsConnection.QueryMultiple(cmdDef))
{
c = multi.Read<int>().Single();
d = multi.Read<int>().Single();
}
int after = SqlMapper.GetCachedSQLCount();
Assert.Equal(0, before);
Assert.Equal(0, after);
Assert.Equal(123, c);
Assert.Equal(456, d);
}
}
}
......@@ -5,7 +5,12 @@
namespace Dapper.Tests
{
public class ConstructorTests : TestBase
public sealed class SystemSqlClientConstructorTests : ConstructorTests<SystemSqlClientProvider> { }
#if MSSQLCLIENT
public sealed class MicrosoftSqlClientConstructorTests : ConstructorTests<MicrosoftSqlClientProvider> { }
#endif
public abstract class ConstructorTests<TProvider> : TestBase<TProvider> where TProvider : DatabaseProvider
{
[Fact]
public void TestAbstractInheritance()
......
......@@ -6,16 +6,24 @@
<GenerateDocumentationFile>false</GenerateDocumentationFile>
<AutoGenerateBindingRedirects>true</AutoGenerateBindingRedirects>
<GenerateBindingRedirectsOutputType>true</GenerateBindingRedirectsOutputType>
<TargetFrameworks>net452;netcoreapp1.0;netcoreapp2.0</TargetFrameworks>
<TargetFrameworks>netcoreapp2.1;net46;netcoreapp2.0;net472</TargetFrameworks>
<TreatWarningsAsErrors>false</TreatWarningsAsErrors>
</PropertyGroup>
<PropertyGroup Condition=" '$(TargetFramework)' == 'net452' ">
<PropertyGroup Condition="'$(TargetFramework)' == 'net46' OR '$(TargetFramework)' == 'net472'">
<DefineConstants>$(DefineConstants);ENTITY_FRAMEWORK;LINQ2SQL;OLEDB</DefineConstants>
</PropertyGroup>
<ItemGroup>
<None Remove="Test.DB.sdf" />
</ItemGroup>
<!-- these two go together
<PropertyGroup>
<DefineConstants>$(DefineConstants);MSSQLCLIENT</DefineConstants>
</PropertyGroup>
<ItemGroup>
<PackageReference Include="Microsoft.Data.SqlClient" Version="1.0.19128.1-Preview" />
</ItemGroup>
-->
<ItemGroup>
<ProjectReference Include="..\Dapper\Dapper.csproj" />
<ProjectReference Include="..\Dapper.Contrib\Dapper.Contrib.csproj" />
......@@ -26,9 +34,10 @@
<PackageReference Include="System.ValueTuple" Version="4.4.0" />
<PackageReference Include="xunit" Version="$(xUnitVersion)" />
<PackageReference Include="xunit.runner.visualstudio" Version="$(xUnitVersion)" />
<PackageReference Include="System.Data.SqlClient" Version="4.6.0" />
</ItemGroup>
<ItemGroup Condition="'$(TargetFramework)' == 'net452'">
<ItemGroup Condition="'$(TargetFramework)' == 'net46' OR '$(TargetFramework)' == 'net472'">
<ProjectReference Include="..\Dapper.EntityFramework\Dapper.EntityFramework.csproj" />
<PackageReference Include="Microsoft.Data.Sqlite" Version="1.1.1" />
<PackageReference Include="Microsoft.SqlServer.Types" Version="14.0.314.76" />
......@@ -48,6 +57,9 @@
<ItemGroup Condition="'$(TargetFramework)' == 'netcoreapp2.0'">
<PackageReference Include="Microsoft.Data.Sqlite" Version="2.0.0" />
</ItemGroup>
<ItemGroup Condition="'$(TargetFramework)' == 'netcoreapp2.1'">
<PackageReference Include="Microsoft.Data.Sqlite" Version="2.0.0" />
</ItemGroup>
<PropertyGroup>
<PostBuildEvent>
......
......@@ -4,7 +4,12 @@
namespace Dapper.Tests
{
public class DataReaderTests : TestBase
public sealed class SystemSqlClientDataReaderTests : DataReaderTests<SystemSqlClientProvider> { }
#if MSSQLCLIENT
public sealed class MicrosoftSqlClientDataReaderTests : DataReaderTests<MicrosoftSqlClientProvider> { }
#endif
public abstract class DataReaderTests<TProvider> : TestBase<TProvider> where TProvider : DatabaseProvider
{
[Fact]
public void GetSameReaderForSameShape()
......
......@@ -5,7 +5,11 @@
namespace Dapper.Tests
{
public class DecimalTests : TestBase
public sealed class SystemSqlClientDecimalTests : DecimalTests<SystemSqlClientProvider> { }
#if MSSQLCLIENT
public sealed class MicrosoftSqlClientDecimalTests : DecimalTests<MicrosoftSqlClientProvider> { }
#endif
public abstract class DecimalTests<TProvider> : TestBase<TProvider> where TProvider : DatabaseProvider
{
[Fact]
public void Issue261_Decimals()
......
......@@ -4,7 +4,11 @@
namespace Dapper.Tests
{
public class EnumTests : TestBase
public sealed class SystemSqlClientEnumTests : EnumTests<SystemSqlClientProvider> { }
#if MSSQLCLIENT
public sealed class MicrosoftSqlClientEnumTests : EnumTests<MicrosoftSqlClientProvider> { }
#endif
public abstract class EnumTests<TProvider> : TestBase<TProvider> where TProvider : DatabaseProvider
{
[Fact]
public void TestEnumWeirdness()
......
using System;
using System.Data.SqlClient;
using Xunit;
namespace Dapper.Tests
......@@ -32,7 +31,7 @@ public FactRequiredCompatibilityLevelAttribute(int level) : base()
public static readonly int DetectedLevel;
static FactRequiredCompatibilityLevelAttribute()
{
using (var conn = TestBase.GetOpenConnection())
using (var conn = DatabaseProvider<SystemSqlClientProvider>.Instance.GetOpenConnection())
{
try
{
......@@ -57,15 +56,16 @@ public FactUnlessCaseSensitiveDatabaseAttribute() : base()
public static readonly bool IsCaseSensitive;
static FactUnlessCaseSensitiveDatabaseAttribute()
{
using (var conn = TestBase.GetOpenConnection())
using (var conn = DatabaseProvider<SystemSqlClientProvider>.Instance.GetOpenConnection())
{
try
{
conn.Execute("declare @i int; set @I = 1;");
}
catch (SqlException s)
catch (Exception ex) when (ex.GetType().Name == "SqlException")
{
if (s.Number == 137)
int err = ((dynamic)ex).Number;
if (err == 137)
IsCaseSensitive = true;
else
throw;
......
......@@ -3,7 +3,11 @@
namespace Dapper.Tests
{
public class LiteralTests : TestBase
public sealed class SystemSqlClientLiteralTests : LiteralTests<SystemSqlClientProvider> { }
#if MSSQLCLIENT
public sealed class MicrosoftSqlClientLiteralTests : LiteralTests<MicrosoftSqlClientProvider> { }
#endif
public abstract class LiteralTests<TProvider> : TestBase<TProvider> where TProvider : DatabaseProvider
{
[Fact]
public void LiteralReplacementEnumAndString()
......
......@@ -3,7 +3,6 @@
using System.Collections.Generic;
using System.Data;
using System.Data.Common;
using System.Data.SqlClient;
using System.Diagnostics;
using System.Linq;
using Xunit;
......@@ -40,7 +39,11 @@ public GenericUriParser(GenericUriParserOptions options)
namespace Dapper.Tests
{
public class MiscTests : TestBase
public sealed class SystemSqlClientMiscTests : MiscTests<SystemSqlClientProvider> { }
#if MSSQLCLIENT
public sealed class MicrosoftSqlClientMiscTests : MiscTests<MicrosoftSqlClientProvider> { }
#endif
public abstract class MiscTests<TProvider> : TestBase<TProvider> where TProvider : DatabaseProvider
{
[Fact]
public void TestNullableGuidSupport()
......@@ -1026,7 +1029,9 @@ public void Issue178_SqlServer()
try { connection.Execute("create table Issue178(id int not null)"); }
catch { /* don't care */ }
// raw ADO.net
var sqlCmd = new SqlCommand(sql, connection);
using (var sqlCmd = connection.CreateCommand())
{
sqlCmd.CommandText = sql;
using (IDataReader reader1 = sqlCmd.ExecuteReader())
{
Assert.True(reader1.Read());
......@@ -1034,6 +1039,7 @@ public void Issue178_SqlServer()
Assert.False(reader1.Read());
Assert.False(reader1.NextResult());
}
}
// dapper
using (var reader2 = connection.ExecuteReader(sql))
......
......@@ -6,7 +6,11 @@
namespace Dapper.Tests
{
public class MultiMapTests : TestBase
public sealed class SystemSqlClientMultiMapTests : MultiMapTests<SystemSqlClientProvider> { }
#if MSSQLCLIENT
public sealed class MicrosoftSqlClientMultiMapTests : MultiMapTests<MicrosoftSqlClientProvider> { }
#endif
public abstract class MultiMapTests<TProvider> : TestBase<TProvider> where TProvider : DatabaseProvider
{
[Fact]
public void ParentChildIdentityAssociations()
......
......@@ -3,7 +3,13 @@
namespace Dapper.Tests
{
[Collection(NonParallelDefinition.Name)]
public class NullTests : TestBase
public sealed class SystemSqlClientNullTests : NullTests<SystemSqlClientProvider> { }
#if MSSQLCLIENT
[Collection(NonParallelDefinition.Name)]
public sealed class MicrosoftSqlClientNullTests : NullTests<MicrosoftSqlClientProvider> { }
#endif
public abstract class NullTests<TProvider> : TestBase<TProvider> where TProvider : DatabaseProvider
{
[Fact]
public void TestNullableDefault()
......
......@@ -3,7 +3,6 @@
using System.Collections.Generic;
using System.ComponentModel;
using System.Data;
using System.Data.SqlClient;
using System.Data.SqlTypes;
using System.Dynamic;
using System.Linq;
......@@ -19,7 +18,13 @@
namespace Dapper.Tests
{
public class ParameterTests : TestBase
[Collection(NonParallelDefinition.Name)] // because it creates SQL types that compete between the two providers
public sealed class SystemSqlClientParameterTests : ParameterTests<SystemSqlClientProvider> { }
#if MSSQLCLIENT
[Collection(NonParallelDefinition.Name)] // because it creates SQL types that compete between the two providers
public sealed class MicrosoftSqlClientParameterTests : ParameterTests<MicrosoftSqlClientProvider> { }
#endif
public abstract class ParameterTests<TProvider> : TestBase<TProvider> where TProvider : DatabaseProvider
{
public class DbParams : SqlMapper.IDynamicParameters, IEnumerable<IDbDataParameter>
{
......@@ -37,7 +42,7 @@ void SqlMapper.IDynamicParameters.AddParameters(IDbCommand command, SqlMapper.Id
command.Parameters.Add(parameter);
}
}
/* problems with conflicting type
private static List<Microsoft.SqlServer.Server.SqlDataRecord> CreateSqlDataRecordList(IEnumerable<int> numbers)
{
var number_list = new List<Microsoft.SqlServer.Server.SqlDataRecord>();
......@@ -56,6 +61,7 @@ private static List<Microsoft.SqlServer.Server.SqlDataRecord> CreateSqlDataRecor
return number_list;
}
private class IntDynamicParam : SqlMapper.IDynamicParameters
{
private readonly IEnumerable<int> numbers;
......@@ -66,7 +72,7 @@ public IntDynamicParam(IEnumerable<int> numbers)
public void AddParameters(IDbCommand command, SqlMapper.Identity identity)
{
var sqlCommand = (SqlCommand)command;
var sqlCommand = (System.Data.SqlClient.SqlCommand)command;
sqlCommand.CommandType = CommandType.StoredProcedure;
var number_list = CreateSqlDataRecordList(numbers);
......@@ -89,7 +95,7 @@ public IntCustomParam(IEnumerable<int> numbers)
public void AddParameter(IDbCommand command, string name)
{
var sqlCommand = (SqlCommand)command;
var sqlCommand = (System.Data.SqlClient.SqlCommand)command;
sqlCommand.CommandType = CommandType.StoredProcedure;
var number_list = CreateSqlDataRecordList(numbers);
......@@ -101,6 +107,7 @@ public void AddParameter(IDbCommand command, string name)
p.Value = number_list;
}
}
*/
/* TODO:
*
......@@ -214,6 +221,9 @@ public void TestMassiveStrings()
Assert.Equal(connection.Query<string>("select @a", new { a = str }).First(), str);
}
/* problems with conflicting type
*
[Fact]
public void TestTVPWithAnonymousObject()
{
......@@ -312,7 +322,7 @@ public new void AddParameters(IDbCommand command, SqlMapper.Identity identity)
{
base.AddParameters(command, identity);
var sqlCommand = (SqlCommand)command;
var sqlCommand = (System.Data.SqlClient.SqlCommand)command;
sqlCommand.CommandType = CommandType.StoredProcedure;
var number_list = CreateSqlDataRecordList(numbers);
......@@ -462,6 +472,8 @@ public void TestSqlDataRecordListParametersWithTypeHandlers()
}
}
*/
#if !NETCOREAPP1_0
[Fact]
public void DataTableParameters()
......@@ -612,14 +624,19 @@ public SO29596645_RuleTableValuedParameters(string parameterName)
public void AddParameters(IDbCommand command, SqlMapper.Identity identity)
{
Debug.WriteLine("> AddParameters");
var lazy = (SqlCommand)command;
lazy.Parameters.AddWithValue("Id", 7);
var p = command.CreateParameter();
p.ParameterName = "Id";
p.Value = 7;
command.Parameters.Add(p);
var table = new DataTable
{
Columns = { { "Id", typeof(int) } },
Rows = { { 4 }, { 9 } }
};
lazy.Parameters.AddWithValue("Rules", table);
p = command.CreateParameter();
p.ParameterName = "Rules";
p.Value = table;
command.Parameters.Add(p);
Debug.WriteLine("< AddParameters");
}
}
......@@ -733,8 +750,8 @@ public class HazSqlHierarchy
public void TestCustomParameters()
{
var args = new DbParams {
new SqlParameter("foo", 123),
new SqlParameter("bar", "abc")
Provider.CreateRawParameter("foo", 123),
Provider.CreateRawParameter("bar", "abc")
};
var result = connection.Query("select Foo=@foo, Bar=@bar", args).Single();
int foo = result.Foo;
......
......@@ -6,7 +6,11 @@
namespace Dapper.Tests
{
public class ProcedureTests : TestBase
public sealed class SystemSqlClientProcedureTests : ProcedureTests<SystemSqlClientProvider> { }
#if MSSQLCLIENT
public sealed class MicrosoftSqlClientProcedureTests : ProcedureTests<MicrosoftSqlClientProvider> { }
#endif
public abstract class ProcedureTests<TProvider> : TestBase<TProvider> where TProvider : DatabaseProvider
{
[Fact]
public void TestProcWithOutParameter()
......
......@@ -6,8 +6,13 @@
namespace Dapper.Tests.Providers
{
public sealed class SystemSqlClientEntityFrameworkTests : EntityFrameworkTests<SystemSqlClientProvider> { }
#if MSSQLCLIENT
public sealed class MicrosoftSqlClientEntityFrameworkTests : EntityFrameworkTests<MicrosoftSqlClientProvider> { }
#endif
[Collection("TypeHandlerTests")]
public class EntityFrameworkTests : TestBase
public abstract class EntityFrameworkTests<TProvider> : TestBase<TProvider> where TProvider : DatabaseProvider
{
public EntityFrameworkTests()
{
......
using FirebirdSql.Data.FirebirdClient;
using System.Data;
using System.Data.Common;
using System.Linq;
using Xunit;
namespace Dapper.Tests.Providers
{
public class FirebirdTests : TestBase
public class FirebirdProvider : DatabaseProvider
{
public override DbProviderFactory Factory => FirebirdClientFactory.Instance;
public override string GetConnectionString() => "initial catalog=localhost:database;user id=SYSDBA;password=masterkey";
}
public class FirebirdTests : TestBase<FirebirdProvider>
{
private FbConnection GetOpenFirebirdConnection() => (FbConnection)Provider.GetOpenConnection();
[Fact(Skip = "Bug in Firebird; a PR to fix it has been submitted")]
public void Issue178_Firebird()
{
const string cs = "initial catalog=localhost:database;user id=SYSDBA;password=masterkey";
using (var connection = new FbConnection(cs))
using (var connection = GetOpenFirebirdConnection())
{
connection.Open();
const string sql = "select count(*) from Issue178";
try { connection.Execute("drop table Issue178"); }
catch { /* don't care */ }
......
......@@ -8,7 +8,11 @@
namespace Dapper.Tests
{
public class Linq2SqlTests : TestBase
public sealed class SystemSqlClientLinq2SqlTests : Linq2SqlTests<SystemSqlClientProvider> { }
#if MSSQLCLIENT
public sealed class MicrosoftSqlClientLinq2SqlTests : Linq2SqlTests<MicrosoftSqlClientProvider> { }
#endif
public abstract class Linq2SqlTests<TProvider> : TestBase<TProvider> where TProvider : DatabaseProvider
{
[Fact]
public void TestLinqBinaryToClass()
......
using System;
using System.Data;
using System.Data.Common;
using System.Linq;
using System.Threading.Tasks;
using Xunit;
namespace Dapper.Tests
{
public class MySQLTests : TestBase
public sealed class MySqlProvider : DatabaseProvider
{
private static MySql.Data.MySqlClient.MySqlConnection GetMySqlConnection(bool open = true,
bool convertZeroDatetime = false, bool allowZeroDatetime = false)
{
string cs = IsAppVeyor
public override DbProviderFactory Factory => MySql.Data.MySqlClient.MySqlClientFactory.Instance;
public override string GetConnectionString() => IsAppVeyor
? "Server=localhost;Database=test;Uid=root;Pwd=Password12!;"
: "Server=localhost;Database=tests;Uid=test;Pwd=pass;";
var csb = new MySql.Data.MySqlClient.MySqlConnectionStringBuilder(cs)
public DbConnection GetMySqlConnection(bool open = true,
bool convertZeroDatetime = false, bool allowZeroDatetime = false)
{
AllowZeroDateTime = allowZeroDatetime,
ConvertZeroDateTime = convertZeroDatetime
};
var conn = new MySql.Data.MySqlClient.MySqlConnection(csb.ConnectionString);
string cs = GetConnectionString();
var csb = Factory.CreateConnectionStringBuilder();
csb.ConnectionString = cs;
((dynamic)csb).AllowZeroDateTime = allowZeroDatetime;
((dynamic)csb).ConvertZeroDateTime = convertZeroDatetime;
var conn = Factory.CreateConnection();
conn.ConnectionString = csb.ConnectionString;
if (open) conn.Open();
return conn;
}
}
public class MySQLTests : TestBase<MySqlProvider>
{
[FactMySql]
public void DapperEnumValue_Mysql()
{
using (var conn = GetMySqlConnection())
using (var conn = Provider.GetMySqlConnection())
{
Common.DapperEnumValue(conn);
}
......@@ -34,7 +42,7 @@ public void DapperEnumValue_Mysql()
[FactMySql(Skip = "See https://github.com/StackExchange/Dapper/issues/552, not resolved on the MySQL end.")]
public void Issue552_SignedUnsignedBooleans()
{
using (var conn = GetMySqlConnection(true, false, false))
using (var conn = Provider.GetMySqlConnection(true, false, false))
{
conn.Execute(@"
CREATE TEMPORARY TABLE IF NOT EXISTS `bar` (
......@@ -74,7 +82,7 @@ private class MySqlHasBool
[FactMySql]
public void Issue295_NullableDateTime_MySql_Default()
{
using (var conn = GetMySqlConnection(true, false, false))
using (var conn = Provider.GetMySqlConnection(true, false, false))
{
Common.TestDateTime(conn);
}
......@@ -83,7 +91,7 @@ public void Issue295_NullableDateTime_MySql_Default()
[FactMySql]
public void Issue295_NullableDateTime_MySql_ConvertZeroDatetime()
{
using (var conn = GetMySqlConnection(true, true, false))
using (var conn = Provider.GetMySqlConnection(true, true, false))
{
Common.TestDateTime(conn);
}
......@@ -92,7 +100,7 @@ public void Issue295_NullableDateTime_MySql_ConvertZeroDatetime()
[FactMySql(Skip = "See https://github.com/StackExchange/Dapper/issues/295, AllowZeroDateTime=True is not supported")]
public void Issue295_NullableDateTime_MySql_AllowZeroDatetime()
{
using (var conn = GetMySqlConnection(true, false, true))
using (var conn = Provider.GetMySqlConnection(true, false, true))
{
Common.TestDateTime(conn);
}
......@@ -101,7 +109,7 @@ public void Issue295_NullableDateTime_MySql_AllowZeroDatetime()
[FactMySql(Skip = "See https://github.com/StackExchange/Dapper/issues/295, AllowZeroDateTime=True is not supported")]
public void Issue295_NullableDateTime_MySql_ConvertAllowZeroDatetime()
{
using (var conn = GetMySqlConnection(true, true, true))
using (var conn = Provider.GetMySqlConnection(true, true, true))
{
Common.TestDateTime(conn);
}
......@@ -110,7 +118,7 @@ public void Issue295_NullableDateTime_MySql_ConvertAllowZeroDatetime()
[FactMySql]
public void Issue426_SO34439033_DateTimeGainsTicks()
{
using (var conn = GetMySqlConnection(true, true, true))
using (var conn = Provider.GetMySqlConnection(true, true, true))
{
try { conn.Execute("drop table Issue426_Test"); } catch { /* don't care */ }
try { conn.Execute("create table Issue426_Test (Id int not null, Time time not null)"); } catch { /* don't care */ }
......@@ -133,7 +141,7 @@ public void Issue426_SO34439033_DateTimeGainsTicks()
[FactMySql]
public void SO36303462_Tinyint_Bools()
{
using (var conn = GetMySqlConnection(true, true, true))
using (var conn = Provider.GetMySqlConnection(true, true, true))
{
try { conn.Execute("drop table SO36303462_Test"); } catch { /* don't care */ }
conn.Execute("create table SO36303462_Test (Id int not null, IsBold tinyint not null);");
......@@ -149,6 +157,52 @@ public void SO36303462_Tinyint_Bools()
}
}
[FactMySql]
public void Issue1277_ReaderSync()
{
using (var conn = Provider.GetMySqlConnection())
{
try { conn.Execute("drop table Issue1277_Test"); } catch { /* don't care */ }
conn.Execute("create table Issue1277_Test (Id int not null, IsBold tinyint not null);");
conn.Execute("insert Issue1277_Test (Id, IsBold) values (1,1);");
conn.Execute("insert Issue1277_Test (Id, IsBold) values (2,0);");
conn.Execute("insert Issue1277_Test (Id, IsBold) values (3,1);");
using (var reader = conn.ExecuteReader(
"select * from Issue1277_Test where Id < @id",
new { id = 42 }))
{
var table = new DataTable();
table.Load(reader);
Assert.Equal(2, table.Columns.Count);
Assert.Equal(3, table.Rows.Count);
}
}
}
[FactMySql]
public async Task Issue1277_ReaderAsync()
{
using (var conn = Provider.GetMySqlConnection())
{
try { await conn.ExecuteAsync("drop table Issue1277_Test"); } catch { /* don't care */ }
await conn.ExecuteAsync("create table Issue1277_Test (Id int not null, IsBold tinyint not null);");
await conn.ExecuteAsync("insert Issue1277_Test (Id, IsBold) values (1,1);");
await conn.ExecuteAsync("insert Issue1277_Test (Id, IsBold) values (2,0);");
await conn.ExecuteAsync("insert Issue1277_Test (Id, IsBold) values (3,1);");
using (var reader = await conn.ExecuteReaderAsync(
"select * from Issue1277_Test where Id < @id",
new { id = 42 }))
{
var table = new DataTable();
table.Load(reader);
Assert.Equal(2, table.Columns.Count);
Assert.Equal(3, table.Rows.Count);
}
}
}
private class SO36303462
{
public int Id { get; set; }
......@@ -176,7 +230,7 @@ static FactMySqlAttribute()
{
try
{
using (GetMySqlConnection(true)) { /* just trying to see if it works */ }
using (DatabaseProvider<MySqlProvider>.Instance.GetMySqlConnection(true)) { /* just trying to see if it works */ }
}
catch (Exception ex)
{
......
#if OLEDB
using System;
using System.Data.Common;
using System.Data.OleDb;
using System.Linq;
using Xunit;
namespace Dapper.Tests
{
public class OLDEBTests : TestBase
public class OLEDBProvider : DatabaseProvider
{
public static string OleDbConnectionString =>
public override DbProviderFactory Factory => OleDbFactory.Instance;
public override string GetConnectionString() =>
IsAppVeyor
? @"Provider=SQLOLEDB;Data Source=(local)\SQL2016;Initial Catalog=tempdb;User Id=sa;Password=Password12!"
: "Provider=SQLOLEDB;Data Source=.;Initial Catalog=tempdb;Integrated Security=SSPI";
}
public OleDbConnection GetOleDbConnection()
public class OLDEBTests : TestBase<OLEDBProvider>
{
var conn = new OleDbConnection(OleDbConnectionString);
conn.Open();
return conn;
}
public OleDbConnection GetOleDbConnection() => (OleDbConnection) Provider.GetOpenConnection();
// see https://stackoverflow.com/q/18847510/23354
[Fact]
......
using System;
using System.Data;
using System.Data.Common;
using System.Linq;
using Xunit;
namespace Dapper.Tests
{
public class PostgresqlTests : TestBase
public class PostgresProvider : DatabaseProvider
{
private static Npgsql.NpgsqlConnection GetOpenNpgsqlConnection()
{
string cs = IsAppVeyor
public override DbProviderFactory Factory => Npgsql.NpgsqlFactory.Instance;
public override string GetConnectionString() => IsAppVeyor
? "Server=localhost;Port=5432;User Id=postgres;Password=Password12!;Database=test"
: "Server=localhost;Port=5432;User Id=dappertest;Password=dapperpass;Database=dappertest"; // ;Encoding = UNICODE
var conn = new Npgsql.NpgsqlConnection(cs);
conn.Open();
return conn;
}
public class PostgresqlTests : TestBase<PostgresProvider>
{
private Npgsql.NpgsqlConnection GetOpenNpgsqlConnection() => (Npgsql.NpgsqlConnection)Provider.GetOpenConnection();
private class Cat
{
......@@ -71,7 +71,7 @@ static FactPostgresqlAttribute()
{
try
{
using (GetOpenNpgsqlConnection()) { /* just trying to see if it works */ }
using (DatabaseProvider<PostgresProvider>.Instance.GetOpenConnection()) { /* just trying to see if it works */ }
}
catch (Exception ex)
{
......
using Microsoft.Data.Sqlite;
using System;
using System.Data.Common;
using System.Linq;
using System.Threading.Tasks;
using Xunit;
namespace Dapper.Tests
{
public class SqliteTests : TestBase
public class SqliteProvider : DatabaseProvider
{
protected static SqliteConnection GetSQLiteConnection(bool open = true)
public override DbProviderFactory Factory => SqliteFactory.Instance;
public override string GetConnectionString() => "Data Source=:memory:";
}
public abstract class SqliteTypeTestBase : TestBase<SqliteProvider>
{
protected SqliteConnection GetSQLiteConnection(bool open = true)
=> (SqliteConnection)(open ? Provider.GetOpenConnection() : Provider.GetClosedConnection());
[AttributeUsage(AttributeTargets.Method, AllowMultiple = false)]
public class FactSqliteAttribute : FactAttribute
{
public override string Skip
{
var connection = new SqliteConnection("Data Source=:memory:");
if (open) connection.Open();
return connection;
get { return unavailable ?? base.Skip; }
set { base.Skip = value; }
}
[FactSqlite]
public void DapperEnumValue_Sqlite()
private static readonly string unavailable;
static FactSqliteAttribute()
{
using (var connection = GetSQLiteConnection())
try
{
Common.DapperEnumValue(connection);
using (DatabaseProvider<SqliteProvider>.Instance.GetOpenConnection())
{
}
}
catch (Exception ex)
{
unavailable = $"Sqlite is unavailable: {ex.Message}";
}
}
}
}
[Collection(NonParallelDefinition.Name)]
public class SqliteTypeHandlerTests : TestBase
public class SqliteTypeHandlerTests : SqliteTypeTestBase
{
[FactSqlite]
public void Issue466_SqliteHatesOptimizations()
......@@ -66,6 +87,19 @@ public async Task Issue466_SqliteHatesOptimizations_Async()
}
}
public class SqliteTests : SqliteTypeTestBase
{
[FactSqlite]
public void DapperEnumValue_Sqlite()
{
using (var connection = GetSQLiteConnection())
{
Common.DapperEnumValue(connection);
}
}
[FactSqlite]
public void Isse467_SqliteLikesParametersWithPrefix()
{
......@@ -90,31 +124,5 @@ private void Isse467_SqliteParameterNaming(bool prefix)
Assert.Equal(42, i);
}
}
[AttributeUsage(AttributeTargets.Method, AllowMultiple = false)]
public class FactSqliteAttribute : FactAttribute
{
public override string Skip
{
get { return unavailable ?? base.Skip; }
set { base.Skip = value; }
}
private static readonly string unavailable;
static FactSqliteAttribute()
{
try
{
using (GetSQLiteConnection())
{
}
}
catch (Exception ex)
{
unavailable = $"Sqlite is unavailable: {ex.Message}";
}
}
}
}
}
......@@ -6,7 +6,11 @@
namespace Dapper.Tests
{
public class QueryMultipleTests : TestBase
public sealed class SystemSqlClientQueryMultipleTests : QueryMultipleTests<SystemSqlClientProvider> { }
#if MSSQLCLIENT
public sealed class MicrosoftSqlClientQueryMultipleTests : QueryMultipleTests<MicrosoftSqlClientProvider> { }
#endif
public abstract class QueryMultipleTests<TProvider> : TestBase<TProvider> where TProvider : DatabaseProvider
{
[Fact]
public void TestQueryMultipleBuffered()
......
using System;
using System.Data;
using System.Data.SqlClient;
using System.Globalization;
using Xunit;
using System.Data.Common;
#if !NETCOREAPP1_0
using System.Threading;
#endif
namespace Dapper.Tests
{
public abstract class TestBase : IDisposable
public static class DatabaseProvider<TProvider> where TProvider : DatabaseProvider
{
protected static readonly bool IsAppVeyor = Environment.GetEnvironmentVariable("Appveyor")?.ToUpperInvariant() == "TRUE";
public static string ConnectionString =>
IsAppVeyor
? @"Server=(local)\SQL2016;Database=tempdb;User ID=sa;Password=Password12!"
: "Data Source=.;Initial Catalog=tempdb;Integrated Security=True";
public static TProvider Instance { get; } = Activator.CreateInstance<TProvider>();
}
public abstract class DatabaseProvider
{
public abstract DbProviderFactory Factory { get; }
protected SqlConnection _connection;
protected SqlConnection connection => _connection ?? (_connection = GetOpenConnection());
public static bool IsAppVeyor { get; } = Environment.GetEnvironmentVariable("Appveyor")?.ToUpperInvariant() == "TRUE";
public virtual void Dispose() { }
public abstract string GetConnectionString();
public static SqlConnection GetOpenConnection(bool mars = false)
public DbConnection GetOpenConnection()
{
var cs = ConnectionString;
if (mars)
var conn = Factory.CreateConnection();
conn.ConnectionString = GetConnectionString();
conn.Open();
if (conn.State != ConnectionState.Open) throw new InvalidOperationException("should be open!");
return conn;
}
public DbConnection GetClosedConnection()
{
var scsb = new SqlConnectionStringBuilder(cs)
var conn = Factory.CreateConnection();
conn.ConnectionString = GetConnectionString();
if (conn.State != ConnectionState.Closed) throw new InvalidOperationException("should be closed!");
return conn;
}
public DbParameter CreateRawParameter(string name, object value)
{
MultipleActiveResultSets = true
};
cs = scsb.ConnectionString;
var p = Factory.CreateParameter();
p.ParameterName = name;
p.Value = value ?? DBNull.Value;
return p;
}
var connection = new SqlConnection(cs);
connection.Open();
return connection;
}
public SqlConnection GetClosedConnection()
public abstract class SqlServerDatabaseProvider : DatabaseProvider
{
var conn = new SqlConnection(ConnectionString);
if (conn.State != ConnectionState.Closed) throw new InvalidOperationException("should be closed!");
public override string GetConnectionString() =>
IsAppVeyor
? @"Server=(local)\SQL2016;Database=tempdb;User ID=sa;Password=Password12!"
: "Data Source=.;Initial Catalog=tempdb;Integrated Security=True";
public DbConnection GetOpenConnection(bool mars)
{
if (!mars) return GetOpenConnection();
var scsb = Factory.CreateConnectionStringBuilder();
scsb.ConnectionString = GetConnectionString();
((dynamic)scsb).MultipleActiveResultSets = true;
var conn = Factory.CreateConnection();
conn.ConnectionString = scsb.ConnectionString;
conn.Open();
if (conn.State != ConnectionState.Open) throw new InvalidOperationException("should be open!");
return conn;
}
}
public sealed class SystemSqlClientProvider : SqlServerDatabaseProvider
{
public override DbProviderFactory Factory => System.Data.SqlClient.SqlClientFactory.Instance;
}
#if MSSQLCLIENT
public sealed class MicrosoftSqlClientProvider : SqlServerDatabaseProvider
{
public override DbProviderFactory Factory => Microsoft.Data.SqlClient.SqlClientFactory.Instance;
}
#endif
public abstract class TestBase<TProvider> : IDisposable where TProvider : DatabaseProvider
{
protected DbConnection GetOpenConnection() => Provider.GetOpenConnection();
protected DbConnection GetClosedConnection() => Provider.GetClosedConnection();
protected DbConnection _connection;
protected DbConnection connection => _connection ?? (_connection = Provider.GetOpenConnection());
public TProvider Provider { get; } = DatabaseProvider<TProvider>.Instance;
protected static CultureInfo ActiveCulture
{
......@@ -58,7 +102,10 @@ protected static CultureInfo ActiveCulture
static TestBase()
{
Console.WriteLine("Dapper: " + typeof(SqlMapper).AssemblyQualifiedName);
Console.WriteLine("Using Connectionstring: {0}", ConnectionString);
var provider = DatabaseProvider<TProvider>.Instance;
Console.WriteLine("Using Connectionstring: {0}", provider.GetConnectionString());
var factory = provider.Factory;
Console.WriteLine("Using Provider: {0}", factory.GetType().FullName);
#if NETCOREAPP1_0
Console.WriteLine("CoreCLR (netcoreapp1.0)");
#else
......@@ -77,14 +124,15 @@ static TestBase()
#endif
}
public void Dispose()
public virtual void Dispose()
{
_connection?.Dispose();
_connection = null;
Provider?.Dispose();
}
}
[CollectionDefinition(Name, DisableParallelization = true)]
public class NonParallelDefinition : TestBase
public static class NonParallelDefinition
{
public const string Name = "NonParallel";
}
......
......@@ -7,7 +7,11 @@
namespace Dapper.Tests
{
public class TransactionTests : TestBase
public sealed class SystemSqlClientTransactionTests : TransactionTests<SystemSqlClientProvider> { }
#if MSSQLCLIENT
public sealed class MicrosoftSqlClientTransactionTests : TransactionTests<MicrosoftSqlClientProvider> { }
#endif
public abstract class TransactionTests<TProvider> : TestBase<TProvider> where TProvider : DatabaseProvider
{
[Fact]
public void TestTransactionCommit()
......
......@@ -3,7 +3,11 @@
namespace Dapper.Tests
{
public class TupleTests : TestBase
public sealed class SystemSqlClientTupleTests : TupleTests<SystemSqlClientProvider> { }
#if MSSQLCLIENT
public sealed class MicrosoftSqlClientTupleTests : TupleTests<MicrosoftSqlClientProvider> { }
#endif
public abstract class TupleTests<TProvider> : TestBase<TProvider> where TProvider : DatabaseProvider
{
[Fact]
public void TupleStructParameter_Fails_HelpfulMessage()
......
......@@ -9,7 +9,13 @@
namespace Dapper.Tests
{
[Collection(NonParallelDefinition.Name)]
public class TypeHandlerTests : TestBase
public sealed class SystemSqlClientTypeHandlerTests : TypeHandlerTests<SystemSqlClientProvider> { }
#if MSSQLCLIENT
[Collection(NonParallelDefinition.Name)]
public sealed class MicrosoftSqlClientTypeHandlerTests : TypeHandlerTests<MicrosoftSqlClientProvider> { }
#endif
public abstract class TypeHandlerTests<TProvider> : TestBase<TProvider> where TProvider : DatabaseProvider
{
[Fact]
public void TestChangingDefaultStringTypeMappingToAnsiString()
......
......@@ -4,7 +4,11 @@
namespace Dapper.Tests
{
public class XmlTests : TestBase
public sealed class SystemSqlClientXmlTests : XmlTests<SystemSqlClientProvider> { }
#if MSSQLCLIENT
public sealed class MicrosoftSqlClientXmlTests : XmlTests<MicrosoftSqlClientProvider> { }
#endif
public abstract class XmlTests<TProvider> : TestBase<TProvider> where TProvider : DatabaseProvider
{
[Fact]
public void CommonXmlTypesSupported()
......
......@@ -5,7 +5,7 @@
<Title>Dapper</Title>
<Description>A high performance Micro-ORM supporting SQL Server, MySQL, Sqlite, SqlCE, Firebird etc..</Description>
<Authors>Sam Saffron;Marc Gravell;Nick Craver</Authors>
<TargetFrameworks>net451;netstandard1.3;netstandard2.0</TargetFrameworks>
<TargetFrameworks>net451;netstandard1.3;netstandard2.0;netcoreapp2.1</TargetFrameworks>
</PropertyGroup>
<ItemGroup Condition="'$(TargetFramework)' == 'net451'">
<Reference Include="System" />
......@@ -15,9 +15,15 @@
<Reference Include="Microsoft.CSharp" />
</ItemGroup>
<ItemGroup Condition=" '$(TargetFramework)' == 'netstandard1.3' OR '$(TargetFramework)' == 'netstandard2.0'">
<PackageReference Include="System.Data.SqlClient" Version="4.6.0" />
<PackageReference Include="System.Reflection.Emit.Lightweight" Version="4.3.0" />
<PackageReference Include="System.Reflection.TypeExtensions" Version="4.4.0" />
<!-- it would be nice to use System.Data.Common here, but we need SqlClient for SqlDbType in 1.3, and legacy SqlDataRecord API-->
<!--<PackageReference Include="System.Data.Common" Version="4.3.0" />-->
<PackageReference Include="System.Data.SqlClient" Version="4.6.0" />
</ItemGroup>
<ItemGroup Condition="'$(TargetFramework)' == 'netcoreapp2.1'">
<PackageReference Include="System.Data.SqlClient" Version="4.6.0" />
</ItemGroup>
<ItemGroup Condition=" '$(TargetFramework)' == 'netstandard1.3' ">
<PackageReference Include="System.Collections.Concurrent" Version="4.3.0" />
......
......@@ -4,7 +4,10 @@
namespace Dapper
{
internal sealed class SqlDataRecordHandler : SqlMapper.ITypeHandler
internal sealed class SqlDataRecordHandler<T> : SqlMapper.ITypeHandler
#if !NETSTANDARD1_3
where T : IDataRecord
#endif
{
public object Parse(Type destinationType, object value)
{
......@@ -13,7 +16,7 @@ public object Parse(Type destinationType, object value)
public void SetValue(IDbDataParameter parameter, object value)
{
SqlDataRecordListTVPParameter.Set(parameter, value as IEnumerable<Microsoft.SqlServer.Server.SqlDataRecord>, null);
SqlDataRecordListTVPParameter<T>.Set(parameter, value as IEnumerable<T>, null);
}
}
}
using System.Collections.Generic;
using System;
using System.Collections;
using System.Collections.Generic;
using System.Data;
using System.Linq;
using System.Reflection;
using System.Reflection.Emit;
namespace Dapper
{
/// <summary>
/// Used to pass a IEnumerable&lt;SqlDataRecord&gt; as a SqlDataRecordListTVPParameter
/// </summary>
internal sealed class SqlDataRecordListTVPParameter : SqlMapper.ICustomQueryParameter
internal sealed class SqlDataRecordListTVPParameter<T> : SqlMapper.ICustomQueryParameter
#if !NETSTANDARD1_3
where T : IDataRecord
#endif
{
private readonly IEnumerable<Microsoft.SqlServer.Server.SqlDataRecord> data;
private readonly IEnumerable<T> data;
private readonly string typeName;
/// <summary>
/// Create a new instance of <see cref="SqlDataRecordListTVPParameter"/>.
/// Create a new instance of <see cref="SqlDataRecordListTVPParameter&lt;T&gt;"/>.
/// </summary>
/// <param name="data">The data records to convert into TVPs.</param>
/// <param name="typeName">The parameter type name.</param>
public SqlDataRecordListTVPParameter(IEnumerable<Microsoft.SqlServer.Server.SqlDataRecord> data, string typeName)
public SqlDataRecordListTVPParameter(IEnumerable<T> data, string typeName)
{
this.data = data;
this.typeName = typeName;
......@@ -30,14 +37,72 @@ void SqlMapper.ICustomQueryParameter.AddParameter(IDbCommand command, string nam
command.Parameters.Add(param);
}
internal static void Set(IDbDataParameter parameter, IEnumerable<Microsoft.SqlServer.Server.SqlDataRecord> data, string typeName)
internal static void Set(IDbDataParameter parameter, IEnumerable<T> data, string typeName)
{
parameter.Value = data != null && data.Any() ? data : null;
if (parameter is System.Data.SqlClient.SqlParameter sqlParam)
StructuredHelper.ConfigureTVP(parameter, typeName);
}
}
static class StructuredHelper
{
private static readonly Hashtable s_udt = new Hashtable(), s_tvp = new Hashtable();
private static Action<IDbDataParameter, string> GetUDT(Type type)
=> (Action<IDbDataParameter, string>)s_udt[type] ?? SlowGetHelper(type, s_udt, "UdtTypeName", 29); // 29 = SqlDbType.Udt (avoiding ref)
private static Action<IDbDataParameter, string> GetTVP(Type type)
=> (Action<IDbDataParameter, string>)s_tvp[type] ?? SlowGetHelper(type, s_tvp, "TypeName", 30); // 30 = SqlDbType.Structured (avoiding ref)
static Action<IDbDataParameter, string> SlowGetHelper(Type type, Hashtable hashtable, string nameProperty, int sqlDbType)
{
lock (hashtable)
{
sqlParam.SqlDbType = SqlDbType.Structured;
sqlParam.TypeName = typeName;
var helper = (Action<IDbDataParameter, string>)hashtable[type];
if (helper == null)
{
helper = CreateFor(type, nameProperty, sqlDbType);
hashtable.Add(type, helper);
}
return helper;
}
}
static Action<IDbDataParameter, string> CreateFor(Type type, string nameProperty, int sqlDbType)
{
var name = type.GetProperty(nameProperty, BindingFlags.Public | BindingFlags.Instance);
if (name == null || !name.CanWrite)
{
return (p, n) => { };
}
var dbType = type.GetProperty("SqlDbType", BindingFlags.Public | BindingFlags.Instance);
if (dbType != null && !dbType.CanWrite) dbType = null;
var dm = new DynamicMethod(nameof(CreateFor) + "_" + type.Name, null,
new[] { typeof(IDbDataParameter), typeof(string) }, true);
var il = dm.GetILGenerator();
il.Emit(OpCodes.Ldarg_0);
il.Emit(OpCodes.Castclass, type);
il.Emit(OpCodes.Ldarg_1);
il.EmitCall(OpCodes.Callvirt, name.GetSetMethod(), null);
if (dbType != null)
{
il.Emit(OpCodes.Ldarg_0);
il.Emit(OpCodes.Castclass, type);
il.Emit(OpCodes.Ldc_I4, sqlDbType);
il.EmitCall(OpCodes.Callvirt, dbType.GetSetMethod(), null);
}
il.Emit(OpCodes.Ret);
return (Action<IDbDataParameter, string>)dm.CreateDelegate(typeof(Action<IDbDataParameter, string>));
}
// this needs to be done per-provider; "dynamic" doesn't work well on all runtimes, although that
// would be a fair option otherwise
internal static void ConfigureUDT(IDbDataParameter parameter, string typeName)
=> GetUDT(parameter.GetType())(parameter, typeName);
internal static void ConfigureTVP(IDbDataParameter parameter, string typeName)
=> GetTVP(parameter.GetType())(parameter, typeName);
}
}
......@@ -1101,7 +1101,7 @@ public static async Task<GridReader> QueryMultipleAsync(this IDbConnection cnn,
/// </code>
/// </example>
public static Task<IDataReader> ExecuteReaderAsync(this IDbConnection cnn, string sql, object param = null, IDbTransaction transaction = null, int? commandTimeout = null, CommandType? commandType = null) =>
ExecuteReaderImplAsync(cnn, new CommandDefinition(sql, param, transaction, commandTimeout, commandType, CommandFlags.Buffered), CommandBehavior.Default);
ExecuteWrappedReaderImplAsync(cnn, new CommandDefinition(sql, param, transaction, commandTimeout, commandType, CommandFlags.Buffered), CommandBehavior.Default);
/// <summary>
/// Execute parameterized SQL and return an <see cref="IDataReader"/>.
......@@ -1114,7 +1114,7 @@ public static async Task<GridReader> QueryMultipleAsync(this IDbConnection cnn,
/// or <see cref="T:DataSet"/>.
/// </remarks>
public static Task<IDataReader> ExecuteReaderAsync(this IDbConnection cnn, CommandDefinition command) =>
ExecuteReaderImplAsync(cnn, command, CommandBehavior.Default);
ExecuteWrappedReaderImplAsync(cnn, command, CommandBehavior.Default);
/// <summary>
/// Execute parameterized SQL and return an <see cref="IDataReader"/>.
......@@ -1128,26 +1128,27 @@ public static async Task<GridReader> QueryMultipleAsync(this IDbConnection cnn,
/// or <see cref="T:DataSet"/>.
/// </remarks>
public static Task<IDataReader> ExecuteReaderAsync(this IDbConnection cnn, CommandDefinition command, CommandBehavior commandBehavior) =>
ExecuteReaderImplAsync(cnn, command, commandBehavior);
ExecuteWrappedReaderImplAsync(cnn, command, commandBehavior);
private static async Task<IDataReader> ExecuteReaderImplAsync(IDbConnection cnn, CommandDefinition command, CommandBehavior commandBehavior)
private static async Task<IDataReader> ExecuteWrappedReaderImplAsync(IDbConnection cnn, CommandDefinition command, CommandBehavior commandBehavior)
{
Action<IDbCommand, object> paramReader = GetParameterReader(cnn, ref command);
DbCommand cmd = null;
bool wasClosed = cnn.State == ConnectionState.Closed;
bool wasClosed = cnn.State == ConnectionState.Closed, disposeCommand = true;
try
{
cmd = command.TrySetupAsyncCommand(cnn, paramReader);
if (wasClosed) await cnn.TryOpenAsync(command.CancellationToken).ConfigureAwait(false);
var reader = await ExecuteReaderWithFlagsFallbackAsync(cmd, wasClosed, commandBehavior, command.CancellationToken).ConfigureAwait(false);
wasClosed = false;
return reader;
disposeCommand = false;
return WrappedReader.Create(cmd, reader);
}
finally
{
if (wasClosed) cnn.Close();
cmd?.Dispose();
if (cmd != null && disposeCommand) cmd.Dispose();
}
}
......
......@@ -237,7 +237,7 @@ private static void ResetTypeHandlers(bool clone)
[MethodImpl(MethodImplOptions.NoInlining)]
private static void AddSqlDataRecordsTypeHandler(bool clone)
{
AddTypeHandlerImpl(typeof(IEnumerable<Microsoft.SqlServer.Server.SqlDataRecord>), new SqlDataRecordHandler(), clone);
AddTypeHandlerImpl(typeof(IEnumerable<Microsoft.SqlServer.Server.SqlDataRecord>), new SqlDataRecordHandler<Microsoft.SqlServer.Server.SqlDataRecord>(), clone);
}
/// <summary>
......@@ -599,7 +599,7 @@ public static IDataReader ExecuteReader(this IDbConnection cnn, string sql, obje
{
var command = new CommandDefinition(sql, param, transaction, commandTimeout, commandType, CommandFlags.Buffered);
var reader = ExecuteReaderImpl(cnn, ref command, CommandBehavior.Default, out IDbCommand dbcmd);
return new WrappedReader(dbcmd, reader);
return WrappedReader.Create(dbcmd, reader);
}
/// <summary>
......@@ -615,7 +615,7 @@ public static IDataReader ExecuteReader(this IDbConnection cnn, string sql, obje
public static IDataReader ExecuteReader(this IDbConnection cnn, CommandDefinition command)
{
var reader = ExecuteReaderImpl(cnn, ref command, CommandBehavior.Default, out IDbCommand dbcmd);
return new WrappedReader(dbcmd, reader);
return WrappedReader.Create(dbcmd, reader);
}
/// <summary>
......@@ -632,7 +632,7 @@ public static IDataReader ExecuteReader(this IDbConnection cnn, CommandDefinitio
public static IDataReader ExecuteReader(this IDbConnection cnn, CommandDefinition command, CommandBehavior commandBehavior)
{
var reader = ExecuteReaderImpl(cnn, ref command, commandBehavior, out IDbCommand dbcmd);
return new WrappedReader(dbcmd, reader);
return WrappedReader.Create(dbcmd, reader);
}
/// <summary>
......@@ -3785,13 +3785,24 @@ public static void SetTypeName(this DataTable table, string typeName)
table?.ExtendedProperties[DataTableTypeNameKey] as string;
#endif
/// <summary>
/// Used to pass a IEnumerable&lt;SqlDataRecord&gt; as a TableValuedParameter.
/// </summary>
/// <param name="list">The list of records to convert to TVPs.</param>
/// <param name="typeName">The sql parameter type name.</param>
public static ICustomQueryParameter AsTableValuedParameter<T>(this IEnumerable<T> list, string typeName = null) where T : IDataRecord =>
new SqlDataRecordListTVPParameter<T>(list, typeName);
/*
/// <summary>
/// Used to pass a IEnumerable&lt;SqlDataRecord&gt; as a TableValuedParameter.
/// </summary>
/// <param name="list">The list of records to convert to TVPs.</param>
/// <param name="typeName">The sql parameter type name.</param>
public static ICustomQueryParameter AsTableValuedParameter(this IEnumerable<Microsoft.SqlServer.Server.SqlDataRecord> list, string typeName = null) =>
new SqlDataRecordListTVPParameter(list, typeName);
new SqlDataRecordListTVPParameter<Microsoft.SqlServer.Server.SqlDataRecord>(list, typeName);
// ^^^ retained to avoid missing-method-exception; can presumably drop in a "major"
*/
// one per thread
[ThreadStatic]
......
using System;
using System.Data;
using System.Reflection;
using System.Data;
#if !NETSTANDARD1_3
namespace Dapper
......@@ -30,17 +28,6 @@ public TableValuedParameter(DataTable table, string typeName)
this.typeName = typeName;
}
private static readonly Action<System.Data.SqlClient.SqlParameter, string> setTypeName;
static TableValuedParameter()
{
var prop = typeof(System.Data.SqlClient.SqlParameter).GetProperty("TypeName", BindingFlags.Instance | BindingFlags.Public);
if (prop != null && prop.PropertyType == typeof(string) && prop.CanWrite)
{
setTypeName = (Action<System.Data.SqlClient.SqlParameter, string>)
Delegate.CreateDelegate(typeof(Action<System.Data.SqlClient.SqlParameter, string>), prop.GetSetMethod());
}
}
void SqlMapper.ICustomQueryParameter.AddParameter(IDbCommand command, string name)
{
var param = command.CreateParameter();
......@@ -58,11 +45,7 @@ internal static void Set(IDbDataParameter parameter, DataTable table, string typ
{
typeName = table.GetTypeName();
}
if (!string.IsNullOrEmpty(typeName) && (parameter is System.Data.SqlClient.SqlParameter sqlParam))
{
setTypeName?.Invoke(sqlParam, typeName);
sqlParam.SqlDbType = SqlDbType.Structured;
}
if (!string.IsNullOrEmpty(typeName)) StructuredHelper.ConfigureTVP(parameter, typeName);
}
}
}
......
......@@ -33,11 +33,7 @@ void ITypeHandler.SetValue(IDbDataParameter parameter, object value)
#pragma warning disable 0618
parameter.Value = SanitizeParameterValue(value);
#pragma warning restore 0618
if (parameter is System.Data.SqlClient.SqlParameter && !(value is DBNull))
{
((System.Data.SqlClient.SqlParameter)parameter).SqlDbType = SqlDbType.Udt;
((System.Data.SqlClient.SqlParameter)parameter).UdtTypeName = udtTypeName;
}
if(!(value is DBNull)) StructuredHelper.ConfigureUDT(parameter, udtTypeName);
}
}
#endif
......
This diff is collapsed.
{
"sdk": {
"version": "2.2.203"
}
}
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment