Unverified Commit 9fbba78d authored by Marc Gravell's avatar Marc Gravell Committed by GitHub

Make Dapper happy with Microsoft.Data.SqlClient (#1262)

* prep work for adding Microsoft.Data.SqlClient support/tests; SqlClient dependency is not yet removed
parent 33d974a5
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
<Title>Dapper (Strong Named)</Title> <Title>Dapper (Strong Named)</Title>
<Description>A high performance Micro-ORM supporting SQL Server, MySQL, Sqlite, SqlCE, Firebird etc..</Description> <Description>A high performance Micro-ORM supporting SQL Server, MySQL, Sqlite, SqlCE, Firebird etc..</Description>
<Authors>Sam Saffron;Marc Gravell;Nick Craver</Authors> <Authors>Sam Saffron;Marc Gravell;Nick Craver</Authors>
<TargetFrameworks>net451;netstandard1.3;netstandard2.0</TargetFrameworks> <TargetFrameworks>net451;netstandard1.3;netstandard2.0;netcoreapp2.1</TargetFrameworks>
<SignAssembly>true</SignAssembly> <SignAssembly>true</SignAssembly>
<PublicSign Condition=" '$(OS)' != 'Windows_NT' ">true</PublicSign> <PublicSign Condition=" '$(OS)' != 'Windows_NT' ">true</PublicSign>
</PropertyGroup> </PropertyGroup>
...@@ -20,9 +20,15 @@ ...@@ -20,9 +20,15 @@
<Reference Include="Microsoft.CSharp" /> <Reference Include="Microsoft.CSharp" />
</ItemGroup> </ItemGroup>
<ItemGroup Condition=" '$(TargetFramework)' == 'netstandard1.3' OR '$(TargetFramework)' == 'netstandard2.0'"> <ItemGroup Condition=" '$(TargetFramework)' == 'netstandard1.3' OR '$(TargetFramework)' == 'netstandard2.0'">
<PackageReference Include="System.Data.SqlClient" Version="4.6.0" />
<PackageReference Include="System.Reflection.Emit.Lightweight" Version="4.3.0" /> <PackageReference Include="System.Reflection.Emit.Lightweight" Version="4.3.0" />
<PackageReference Include="System.Reflection.TypeExtensions" Version="4.4.0" /> <PackageReference Include="System.Reflection.TypeExtensions" Version="4.4.0" />
<!-- it would be nice to use System.Data.Common here, but we need SqlClient for SqlDbType in 1.3, and legacy SqlDataRecord API-->
<!--<PackageReference Include="System.Data.Common" Version="4.3.0" />-->
<PackageReference Include="System.Data.SqlClient" Version="4.6.0" />
</ItemGroup>
<ItemGroup Condition="'$(TargetFramework)' == 'netcoreapp2.1'">
<PackageReference Include="System.Data.SqlClient" Version="4.6.0" />
</ItemGroup> </ItemGroup>
<ItemGroup Condition=" '$(TargetFramework)' == 'netstandard1.3' "> <ItemGroup Condition=" '$(TargetFramework)' == 'netstandard1.3' ">
<PackageReference Include="System.Collections.Concurrent" Version="4.3.0" /> <PackageReference Include="System.Collections.Concurrent" Version="4.3.0" />
......
...@@ -21,5 +21,6 @@ ...@@ -21,5 +21,6 @@
<PackageReference Include="MySqlConnector" Version="0.44.1" /> <PackageReference Include="MySqlConnector" Version="0.44.1" />
<PackageReference Include="xunit" Version="$(xUnitVersion)" /> <PackageReference Include="xunit" Version="$(xUnitVersion)" />
<PackageReference Include="xunit.runner.visualstudio" Version="$(xUnitVersion)" /> <PackageReference Include="xunit.runner.visualstudio" Version="$(xUnitVersion)" />
<PackageReference Include="System.Data.SqlClient" Version="4.6.0" />
</ItemGroup> </ItemGroup>
</Project> </Project>
...@@ -4,15 +4,31 @@ ...@@ -4,15 +4,31 @@
using System; using System;
using System.Threading.Tasks; using System.Threading.Tasks;
using System.Threading; using System.Threading;
using System.Data.SqlClient;
using Xunit; using Xunit;
using System.Data.Common;
namespace Dapper.Tests namespace Dapper.Tests
{ {
public class Tests : TestBase [Collection(NonParallelDefinition.Name)]
public sealed class SystemSqlClientAsyncTests : AsyncTests<SystemSqlClientProvider> { }
#if MSSQLCLIENT
[Collection(NonParallelDefinition.Name)]
public sealed class MicrosoftSqlClientAsyncTests : AsyncTests<MicrosoftSqlClientProvider> { }
#endif
[Collection(NonParallelDefinition.Name)]
public sealed class SystemSqlClientAsyncQueryCacheTests : AsyncQueryCacheTests<SystemSqlClientProvider> { }
#if MSSQLCLIENT
[Collection(NonParallelDefinition.Name)]
public sealed class MicrosoftSqlClientAsyncQueryCacheTests : AsyncQueryCacheTests<MicrosoftSqlClientProvider> { }
#endif
public abstract class AsyncTests<TProvider> : TestBase<TProvider> where TProvider : SqlServerDatabaseProvider
{ {
private SqlConnection _marsConnection; private DbConnection _marsConnection;
private SqlConnection MarsConnection => _marsConnection ?? (_marsConnection = GetOpenConnection(true));
private DbConnection MarsConnection => _marsConnection ?? (_marsConnection = Provider.GetOpenConnection(true));
[Fact] [Fact]
public async Task TestBasicStringUsageAsync() public async Task TestBasicStringUsageAsync()
...@@ -100,7 +116,7 @@ public void TestLongOperationWithCancellation() ...@@ -100,7 +116,7 @@ public void TestLongOperationWithCancellation()
} }
catch (AggregateException agg) catch (AggregateException agg)
{ {
Assert.True(agg.InnerException is SqlException); Assert.True(agg.InnerException.GetType().Name == "SqlException");
} }
} }
...@@ -382,38 +398,6 @@ public void RunSequentialVersusParallelSync() ...@@ -382,38 +398,6 @@ public void RunSequentialVersusParallelSync()
Console.WriteLine("Pipeline: {0}ms", watch.ElapsedMilliseconds); Console.WriteLine("Pipeline: {0}ms", watch.ElapsedMilliseconds);
} }
[Collection(NonParallelDefinition.Name)]
public class AsyncQueryCacheTests : TestBase
{
private SqlConnection _marsConnection;
private SqlConnection MarsConnection => _marsConnection ?? (_marsConnection = GetOpenConnection(true));
[Fact]
public void AssertNoCacheWorksForQueryMultiple()
{
const int a = 123, b = 456;
var cmdDef = new CommandDefinition("select @a; select @b;", new
{
a,
b
}, commandType: CommandType.Text, flags: CommandFlags.NoCache);
int c, d;
SqlMapper.PurgeQueryCache();
int before = SqlMapper.GetCachedSQLCount();
using (var multi = MarsConnection.QueryMultiple(cmdDef))
{
c = multi.Read<int>().Single();
d = multi.Read<int>().Single();
}
int after = SqlMapper.GetCachedSQLCount();
Assert.Equal(0, before);
Assert.Equal(0, after);
Assert.Equal(123, c);
Assert.Equal(456, d);
}
}
private class BasicType private class BasicType
{ {
public string Value { get; set; } public string Value { get; set; }
...@@ -827,7 +811,46 @@ public async Task Issue563_QueryAsyncShouldThrowException() ...@@ -827,7 +811,46 @@ public async Task Issue563_QueryAsyncShouldThrowException()
var data = (await connection.QueryAsync<int>("select 1 union all select 2; RAISERROR('after select', 16, 1);").ConfigureAwait(false)).ToList(); var data = (await connection.QueryAsync<int>("select 1 union all select 2; RAISERROR('after select', 16, 1);").ConfigureAwait(false)).ToList();
Assert.True(false, "Expected Exception"); Assert.True(false, "Expected Exception");
} }
catch (SqlException ex) when (ex.Message == "after select") { /* swallow only this */ } catch (Exception ex) when (ex.GetType().Name == "SqlException" && ex.Message == "after select") { /* swallow only this */ }
}
}
[Collection(NonParallelDefinition.Name)]
public abstract class AsyncQueryCacheTests<TProvider> : TestBase<TProvider> where TProvider : SqlServerDatabaseProvider
{
private DbConnection _marsConnection;
private DbConnection MarsConnection => _marsConnection ?? (_marsConnection = Provider.GetOpenConnection(true));
public override void Dispose()
{
_marsConnection?.Dispose();
_marsConnection = null;
base.Dispose();
}
[Fact]
public void AssertNoCacheWorksForQueryMultiple()
{
const int a = 123, b = 456;
var cmdDef = new CommandDefinition("select @a; select @b;", new
{
a,
b
}, commandType: CommandType.Text, flags: CommandFlags.NoCache);
int c, d;
SqlMapper.PurgeQueryCache();
int before = SqlMapper.GetCachedSQLCount();
using (var multi = MarsConnection.QueryMultiple(cmdDef))
{
c = multi.Read<int>().Single();
d = multi.Read<int>().Single();
}
int after = SqlMapper.GetCachedSQLCount();
Assert.Equal(0, before);
Assert.Equal(0, after);
Assert.Equal(123, c);
Assert.Equal(456, d);
} }
} }
} }
...@@ -5,7 +5,12 @@ ...@@ -5,7 +5,12 @@
namespace Dapper.Tests namespace Dapper.Tests
{ {
public class ConstructorTests : TestBase public sealed class SystemSqlClientConstructorTests : ConstructorTests<SystemSqlClientProvider> { }
#if MSSQLCLIENT
public sealed class MicrosoftSqlClientConstructorTests : ConstructorTests<MicrosoftSqlClientProvider> { }
#endif
public abstract class ConstructorTests<TProvider> : TestBase<TProvider> where TProvider : DatabaseProvider
{ {
[Fact] [Fact]
public void TestAbstractInheritance() public void TestAbstractInheritance()
......
...@@ -6,16 +6,24 @@ ...@@ -6,16 +6,24 @@
<GenerateDocumentationFile>false</GenerateDocumentationFile> <GenerateDocumentationFile>false</GenerateDocumentationFile>
<AutoGenerateBindingRedirects>true</AutoGenerateBindingRedirects> <AutoGenerateBindingRedirects>true</AutoGenerateBindingRedirects>
<GenerateBindingRedirectsOutputType>true</GenerateBindingRedirectsOutputType> <GenerateBindingRedirectsOutputType>true</GenerateBindingRedirectsOutputType>
<TargetFrameworks>net452;netcoreapp1.0;netcoreapp2.0</TargetFrameworks> <TargetFrameworks>netcoreapp2.1;net46;netcoreapp2.0;net472</TargetFrameworks>
<TreatWarningsAsErrors>false</TreatWarningsAsErrors> <TreatWarningsAsErrors>false</TreatWarningsAsErrors>
</PropertyGroup> </PropertyGroup>
<PropertyGroup Condition=" '$(TargetFramework)' == 'net452' "> <PropertyGroup Condition="'$(TargetFramework)' == 'net46' OR '$(TargetFramework)' == 'net472'">
<DefineConstants>$(DefineConstants);ENTITY_FRAMEWORK;LINQ2SQL;OLEDB</DefineConstants> <DefineConstants>$(DefineConstants);ENTITY_FRAMEWORK;LINQ2SQL;OLEDB</DefineConstants>
</PropertyGroup> </PropertyGroup>
<ItemGroup> <ItemGroup>
<None Remove="Test.DB.sdf" /> <None Remove="Test.DB.sdf" />
</ItemGroup> </ItemGroup>
<!-- these two go together
<PropertyGroup>
<DefineConstants>$(DefineConstants);MSSQLCLIENT</DefineConstants>
</PropertyGroup>
<ItemGroup>
<PackageReference Include="Microsoft.Data.SqlClient" Version="1.0.19128.1-Preview" />
</ItemGroup>
-->
<ItemGroup> <ItemGroup>
<ProjectReference Include="..\Dapper\Dapper.csproj" /> <ProjectReference Include="..\Dapper\Dapper.csproj" />
<ProjectReference Include="..\Dapper.Contrib\Dapper.Contrib.csproj" /> <ProjectReference Include="..\Dapper.Contrib\Dapper.Contrib.csproj" />
...@@ -26,9 +34,10 @@ ...@@ -26,9 +34,10 @@
<PackageReference Include="System.ValueTuple" Version="4.4.0" /> <PackageReference Include="System.ValueTuple" Version="4.4.0" />
<PackageReference Include="xunit" Version="$(xUnitVersion)" /> <PackageReference Include="xunit" Version="$(xUnitVersion)" />
<PackageReference Include="xunit.runner.visualstudio" Version="$(xUnitVersion)" /> <PackageReference Include="xunit.runner.visualstudio" Version="$(xUnitVersion)" />
<PackageReference Include="System.Data.SqlClient" Version="4.6.0" />
</ItemGroup> </ItemGroup>
<ItemGroup Condition="'$(TargetFramework)' == 'net452'"> <ItemGroup Condition="'$(TargetFramework)' == 'net46' OR '$(TargetFramework)' == 'net472'">
<ProjectReference Include="..\Dapper.EntityFramework\Dapper.EntityFramework.csproj" /> <ProjectReference Include="..\Dapper.EntityFramework\Dapper.EntityFramework.csproj" />
<PackageReference Include="Microsoft.Data.Sqlite" Version="1.1.1" /> <PackageReference Include="Microsoft.Data.Sqlite" Version="1.1.1" />
<PackageReference Include="Microsoft.SqlServer.Types" Version="14.0.314.76" /> <PackageReference Include="Microsoft.SqlServer.Types" Version="14.0.314.76" />
...@@ -48,6 +57,9 @@ ...@@ -48,6 +57,9 @@
<ItemGroup Condition="'$(TargetFramework)' == 'netcoreapp2.0'"> <ItemGroup Condition="'$(TargetFramework)' == 'netcoreapp2.0'">
<PackageReference Include="Microsoft.Data.Sqlite" Version="2.0.0" /> <PackageReference Include="Microsoft.Data.Sqlite" Version="2.0.0" />
</ItemGroup> </ItemGroup>
<ItemGroup Condition="'$(TargetFramework)' == 'netcoreapp2.1'">
<PackageReference Include="Microsoft.Data.Sqlite" Version="2.0.0" />
</ItemGroup>
<PropertyGroup> <PropertyGroup>
<PostBuildEvent> <PostBuildEvent>
......
...@@ -4,7 +4,12 @@ ...@@ -4,7 +4,12 @@
namespace Dapper.Tests namespace Dapper.Tests
{ {
public class DataReaderTests : TestBase public sealed class SystemSqlClientDataReaderTests : DataReaderTests<SystemSqlClientProvider> { }
#if MSSQLCLIENT
public sealed class MicrosoftSqlClientDataReaderTests : DataReaderTests<MicrosoftSqlClientProvider> { }
#endif
public abstract class DataReaderTests<TProvider> : TestBase<TProvider> where TProvider : DatabaseProvider
{ {
[Fact] [Fact]
public void GetSameReaderForSameShape() public void GetSameReaderForSameShape()
......
...@@ -5,7 +5,11 @@ ...@@ -5,7 +5,11 @@
namespace Dapper.Tests namespace Dapper.Tests
{ {
public class DecimalTests : TestBase public sealed class SystemSqlClientDecimalTests : DecimalTests<SystemSqlClientProvider> { }
#if MSSQLCLIENT
public sealed class MicrosoftSqlClientDecimalTests : DecimalTests<MicrosoftSqlClientProvider> { }
#endif
public abstract class DecimalTests<TProvider> : TestBase<TProvider> where TProvider : DatabaseProvider
{ {
[Fact] [Fact]
public void Issue261_Decimals() public void Issue261_Decimals()
......
...@@ -4,7 +4,11 @@ ...@@ -4,7 +4,11 @@
namespace Dapper.Tests namespace Dapper.Tests
{ {
public class EnumTests : TestBase public sealed class SystemSqlClientEnumTests : EnumTests<SystemSqlClientProvider> { }
#if MSSQLCLIENT
public sealed class MicrosoftSqlClientEnumTests : EnumTests<MicrosoftSqlClientProvider> { }
#endif
public abstract class EnumTests<TProvider> : TestBase<TProvider> where TProvider : DatabaseProvider
{ {
[Fact] [Fact]
public void TestEnumWeirdness() public void TestEnumWeirdness()
......
using System; using System;
using System.Data.SqlClient;
using Xunit; using Xunit;
namespace Dapper.Tests namespace Dapper.Tests
...@@ -32,7 +31,7 @@ public FactRequiredCompatibilityLevelAttribute(int level) : base() ...@@ -32,7 +31,7 @@ public FactRequiredCompatibilityLevelAttribute(int level) : base()
public static readonly int DetectedLevel; public static readonly int DetectedLevel;
static FactRequiredCompatibilityLevelAttribute() static FactRequiredCompatibilityLevelAttribute()
{ {
using (var conn = TestBase.GetOpenConnection()) using (var conn = DatabaseProvider<SystemSqlClientProvider>.Instance.GetOpenConnection())
{ {
try try
{ {
...@@ -57,15 +56,16 @@ public FactUnlessCaseSensitiveDatabaseAttribute() : base() ...@@ -57,15 +56,16 @@ public FactUnlessCaseSensitiveDatabaseAttribute() : base()
public static readonly bool IsCaseSensitive; public static readonly bool IsCaseSensitive;
static FactUnlessCaseSensitiveDatabaseAttribute() static FactUnlessCaseSensitiveDatabaseAttribute()
{ {
using (var conn = TestBase.GetOpenConnection()) using (var conn = DatabaseProvider<SystemSqlClientProvider>.Instance.GetOpenConnection())
{ {
try try
{ {
conn.Execute("declare @i int; set @I = 1;"); conn.Execute("declare @i int; set @I = 1;");
} }
catch (SqlException s) catch (Exception ex) when (ex.GetType().Name == "SqlException")
{ {
if (s.Number == 137) int err = ((dynamic)ex).Number;
if (err == 137)
IsCaseSensitive = true; IsCaseSensitive = true;
else else
throw; throw;
......
...@@ -3,7 +3,11 @@ ...@@ -3,7 +3,11 @@
namespace Dapper.Tests namespace Dapper.Tests
{ {
public class LiteralTests : TestBase public sealed class SystemSqlClientLiteralTests : LiteralTests<SystemSqlClientProvider> { }
#if MSSQLCLIENT
public sealed class MicrosoftSqlClientLiteralTests : LiteralTests<MicrosoftSqlClientProvider> { }
#endif
public abstract class LiteralTests<TProvider> : TestBase<TProvider> where TProvider : DatabaseProvider
{ {
[Fact] [Fact]
public void LiteralReplacementEnumAndString() public void LiteralReplacementEnumAndString()
......
...@@ -3,7 +3,6 @@ ...@@ -3,7 +3,6 @@
using System.Collections.Generic; using System.Collections.Generic;
using System.Data; using System.Data;
using System.Data.Common; using System.Data.Common;
using System.Data.SqlClient;
using System.Diagnostics; using System.Diagnostics;
using System.Linq; using System.Linq;
using Xunit; using Xunit;
...@@ -40,7 +39,11 @@ public GenericUriParser(GenericUriParserOptions options) ...@@ -40,7 +39,11 @@ public GenericUriParser(GenericUriParserOptions options)
namespace Dapper.Tests namespace Dapper.Tests
{ {
public class MiscTests : TestBase public sealed class SystemSqlClientMiscTests : MiscTests<SystemSqlClientProvider> { }
#if MSSQLCLIENT
public sealed class MicrosoftSqlClientMiscTests : MiscTests<MicrosoftSqlClientProvider> { }
#endif
public abstract class MiscTests<TProvider> : TestBase<TProvider> where TProvider : DatabaseProvider
{ {
[Fact] [Fact]
public void TestNullableGuidSupport() public void TestNullableGuidSupport()
...@@ -1026,13 +1029,16 @@ public void Issue178_SqlServer() ...@@ -1026,13 +1029,16 @@ public void Issue178_SqlServer()
try { connection.Execute("create table Issue178(id int not null)"); } try { connection.Execute("create table Issue178(id int not null)"); }
catch { /* don't care */ } catch { /* don't care */ }
// raw ADO.net // raw ADO.net
var sqlCmd = new SqlCommand(sql, connection); using (var sqlCmd = connection.CreateCommand())
using (IDataReader reader1 = sqlCmd.ExecuteReader())
{ {
Assert.True(reader1.Read()); sqlCmd.CommandText = sql;
Assert.Equal(0, reader1.GetInt32(0)); using (IDataReader reader1 = sqlCmd.ExecuteReader())
Assert.False(reader1.Read()); {
Assert.False(reader1.NextResult()); Assert.True(reader1.Read());
Assert.Equal(0, reader1.GetInt32(0));
Assert.False(reader1.Read());
Assert.False(reader1.NextResult());
}
} }
// dapper // dapper
......
...@@ -6,7 +6,11 @@ ...@@ -6,7 +6,11 @@
namespace Dapper.Tests namespace Dapper.Tests
{ {
public class MultiMapTests : TestBase public sealed class SystemSqlClientMultiMapTests : MultiMapTests<SystemSqlClientProvider> { }
#if MSSQLCLIENT
public sealed class MicrosoftSqlClientMultiMapTests : MultiMapTests<MicrosoftSqlClientProvider> { }
#endif
public abstract class MultiMapTests<TProvider> : TestBase<TProvider> where TProvider : DatabaseProvider
{ {
[Fact] [Fact]
public void ParentChildIdentityAssociations() public void ParentChildIdentityAssociations()
......
...@@ -3,7 +3,13 @@ ...@@ -3,7 +3,13 @@
namespace Dapper.Tests namespace Dapper.Tests
{ {
[Collection(NonParallelDefinition.Name)] [Collection(NonParallelDefinition.Name)]
public class NullTests : TestBase public sealed class SystemSqlClientNullTests : NullTests<SystemSqlClientProvider> { }
#if MSSQLCLIENT
[Collection(NonParallelDefinition.Name)]
public sealed class MicrosoftSqlClientNullTests : NullTests<MicrosoftSqlClientProvider> { }
#endif
public abstract class NullTests<TProvider> : TestBase<TProvider> where TProvider : DatabaseProvider
{ {
[Fact] [Fact]
public void TestNullableDefault() public void TestNullableDefault()
......
...@@ -3,7 +3,6 @@ ...@@ -3,7 +3,6 @@
using System.Collections.Generic; using System.Collections.Generic;
using System.ComponentModel; using System.ComponentModel;
using System.Data; using System.Data;
using System.Data.SqlClient;
using System.Data.SqlTypes; using System.Data.SqlTypes;
using System.Dynamic; using System.Dynamic;
using System.Linq; using System.Linq;
...@@ -19,7 +18,13 @@ ...@@ -19,7 +18,13 @@
namespace Dapper.Tests namespace Dapper.Tests
{ {
public class ParameterTests : TestBase [Collection(NonParallelDefinition.Name)] // because it creates SQL types that compete between the two providers
public sealed class SystemSqlClientParameterTests : ParameterTests<SystemSqlClientProvider> { }
#if MSSQLCLIENT
[Collection(NonParallelDefinition.Name)] // because it creates SQL types that compete between the two providers
public sealed class MicrosoftSqlClientParameterTests : ParameterTests<MicrosoftSqlClientProvider> { }
#endif
public abstract class ParameterTests<TProvider> : TestBase<TProvider> where TProvider : DatabaseProvider
{ {
public class DbParams : SqlMapper.IDynamicParameters, IEnumerable<IDbDataParameter> public class DbParams : SqlMapper.IDynamicParameters, IEnumerable<IDbDataParameter>
{ {
...@@ -37,7 +42,7 @@ void SqlMapper.IDynamicParameters.AddParameters(IDbCommand command, SqlMapper.Id ...@@ -37,7 +42,7 @@ void SqlMapper.IDynamicParameters.AddParameters(IDbCommand command, SqlMapper.Id
command.Parameters.Add(parameter); command.Parameters.Add(parameter);
} }
} }
/* problems with conflicting type
private static List<Microsoft.SqlServer.Server.SqlDataRecord> CreateSqlDataRecordList(IEnumerable<int> numbers) private static List<Microsoft.SqlServer.Server.SqlDataRecord> CreateSqlDataRecordList(IEnumerable<int> numbers)
{ {
var number_list = new List<Microsoft.SqlServer.Server.SqlDataRecord>(); var number_list = new List<Microsoft.SqlServer.Server.SqlDataRecord>();
...@@ -55,6 +60,7 @@ private static List<Microsoft.SqlServer.Server.SqlDataRecord> CreateSqlDataRecor ...@@ -55,6 +60,7 @@ private static List<Microsoft.SqlServer.Server.SqlDataRecord> CreateSqlDataRecor
return number_list; return number_list;
} }
private class IntDynamicParam : SqlMapper.IDynamicParameters private class IntDynamicParam : SqlMapper.IDynamicParameters
{ {
...@@ -66,7 +72,7 @@ public IntDynamicParam(IEnumerable<int> numbers) ...@@ -66,7 +72,7 @@ public IntDynamicParam(IEnumerable<int> numbers)
public void AddParameters(IDbCommand command, SqlMapper.Identity identity) public void AddParameters(IDbCommand command, SqlMapper.Identity identity)
{ {
var sqlCommand = (SqlCommand)command; var sqlCommand = (System.Data.SqlClient.SqlCommand)command;
sqlCommand.CommandType = CommandType.StoredProcedure; sqlCommand.CommandType = CommandType.StoredProcedure;
var number_list = CreateSqlDataRecordList(numbers); var number_list = CreateSqlDataRecordList(numbers);
...@@ -78,7 +84,7 @@ public void AddParameters(IDbCommand command, SqlMapper.Identity identity) ...@@ -78,7 +84,7 @@ public void AddParameters(IDbCommand command, SqlMapper.Identity identity)
p.Value = number_list; p.Value = number_list;
} }
} }
private class IntCustomParam : SqlMapper.ICustomQueryParameter private class IntCustomParam : SqlMapper.ICustomQueryParameter
{ {
private readonly IEnumerable<int> numbers; private readonly IEnumerable<int> numbers;
...@@ -89,7 +95,7 @@ public IntCustomParam(IEnumerable<int> numbers) ...@@ -89,7 +95,7 @@ public IntCustomParam(IEnumerable<int> numbers)
public void AddParameter(IDbCommand command, string name) public void AddParameter(IDbCommand command, string name)
{ {
var sqlCommand = (SqlCommand)command; var sqlCommand = (System.Data.SqlClient.SqlCommand)command;
sqlCommand.CommandType = CommandType.StoredProcedure; sqlCommand.CommandType = CommandType.StoredProcedure;
var number_list = CreateSqlDataRecordList(numbers); var number_list = CreateSqlDataRecordList(numbers);
...@@ -101,6 +107,7 @@ public void AddParameter(IDbCommand command, string name) ...@@ -101,6 +107,7 @@ public void AddParameter(IDbCommand command, string name)
p.Value = number_list; p.Value = number_list;
} }
} }
*/
/* TODO: /* TODO:
* *
...@@ -214,6 +221,9 @@ public void TestMassiveStrings() ...@@ -214,6 +221,9 @@ public void TestMassiveStrings()
Assert.Equal(connection.Query<string>("select @a", new { a = str }).First(), str); Assert.Equal(connection.Query<string>("select @a", new { a = str }).First(), str);
} }
/* problems with conflicting type
*
[Fact] [Fact]
public void TestTVPWithAnonymousObject() public void TestTVPWithAnonymousObject()
{ {
...@@ -312,7 +322,7 @@ public new void AddParameters(IDbCommand command, SqlMapper.Identity identity) ...@@ -312,7 +322,7 @@ public new void AddParameters(IDbCommand command, SqlMapper.Identity identity)
{ {
base.AddParameters(command, identity); base.AddParameters(command, identity);
var sqlCommand = (SqlCommand)command; var sqlCommand = (System.Data.SqlClient.SqlCommand)command;
sqlCommand.CommandType = CommandType.StoredProcedure; sqlCommand.CommandType = CommandType.StoredProcedure;
var number_list = CreateSqlDataRecordList(numbers); var number_list = CreateSqlDataRecordList(numbers);
...@@ -462,6 +472,8 @@ public void TestSqlDataRecordListParametersWithTypeHandlers() ...@@ -462,6 +472,8 @@ public void TestSqlDataRecordListParametersWithTypeHandlers()
} }
} }
*/
#if !NETCOREAPP1_0 #if !NETCOREAPP1_0
[Fact] [Fact]
public void DataTableParameters() public void DataTableParameters()
...@@ -612,14 +624,19 @@ public SO29596645_RuleTableValuedParameters(string parameterName) ...@@ -612,14 +624,19 @@ public SO29596645_RuleTableValuedParameters(string parameterName)
public void AddParameters(IDbCommand command, SqlMapper.Identity identity) public void AddParameters(IDbCommand command, SqlMapper.Identity identity)
{ {
Debug.WriteLine("> AddParameters"); Debug.WriteLine("> AddParameters");
var lazy = (SqlCommand)command; var p = command.CreateParameter();
lazy.Parameters.AddWithValue("Id", 7); p.ParameterName = "Id";
p.Value = 7;
command.Parameters.Add(p);
var table = new DataTable var table = new DataTable
{ {
Columns = { { "Id", typeof(int) } }, Columns = { { "Id", typeof(int) } },
Rows = { { 4 }, { 9 } } Rows = { { 4 }, { 9 } }
}; };
lazy.Parameters.AddWithValue("Rules", table); p = command.CreateParameter();
p.ParameterName = "Rules";
p.Value = table;
command.Parameters.Add(p);
Debug.WriteLine("< AddParameters"); Debug.WriteLine("< AddParameters");
} }
} }
...@@ -733,8 +750,8 @@ public class HazSqlHierarchy ...@@ -733,8 +750,8 @@ public class HazSqlHierarchy
public void TestCustomParameters() public void TestCustomParameters()
{ {
var args = new DbParams { var args = new DbParams {
new SqlParameter("foo", 123), Provider.CreateRawParameter("foo", 123),
new SqlParameter("bar", "abc") Provider.CreateRawParameter("bar", "abc")
}; };
var result = connection.Query("select Foo=@foo, Bar=@bar", args).Single(); var result = connection.Query("select Foo=@foo, Bar=@bar", args).Single();
int foo = result.Foo; int foo = result.Foo;
......
...@@ -6,7 +6,11 @@ ...@@ -6,7 +6,11 @@
namespace Dapper.Tests namespace Dapper.Tests
{ {
public class ProcedureTests : TestBase public sealed class SystemSqlClientProcedureTests : ProcedureTests<SystemSqlClientProvider> { }
#if MSSQLCLIENT
public sealed class MicrosoftSqlClientProcedureTests : ProcedureTests<MicrosoftSqlClientProvider> { }
#endif
public abstract class ProcedureTests<TProvider> : TestBase<TProvider> where TProvider : DatabaseProvider
{ {
[Fact] [Fact]
public void TestProcWithOutParameter() public void TestProcWithOutParameter()
......
...@@ -6,8 +6,13 @@ ...@@ -6,8 +6,13 @@
namespace Dapper.Tests.Providers namespace Dapper.Tests.Providers
{ {
public sealed class SystemSqlClientEntityFrameworkTests : EntityFrameworkTests<SystemSqlClientProvider> { }
#if MSSQLCLIENT
public sealed class MicrosoftSqlClientEntityFrameworkTests : EntityFrameworkTests<MicrosoftSqlClientProvider> { }
#endif
[Collection("TypeHandlerTests")] [Collection("TypeHandlerTests")]
public class EntityFrameworkTests : TestBase public abstract class EntityFrameworkTests<TProvider> : TestBase<TProvider> where TProvider : DatabaseProvider
{ {
public EntityFrameworkTests() public EntityFrameworkTests()
{ {
......
using FirebirdSql.Data.FirebirdClient; using FirebirdSql.Data.FirebirdClient;
using System.Data; using System.Data;
using System.Data.Common;
using System.Linq; using System.Linq;
using Xunit; using Xunit;
namespace Dapper.Tests.Providers namespace Dapper.Tests.Providers
{ {
public class FirebirdTests : TestBase public class FirebirdProvider : DatabaseProvider
{ {
public override DbProviderFactory Factory => FirebirdClientFactory.Instance;
public override string GetConnectionString() => "initial catalog=localhost:database;user id=SYSDBA;password=masterkey";
}
public class FirebirdTests : TestBase<FirebirdProvider>
{
private FbConnection GetOpenFirebirdConnection() => (FbConnection)Provider.GetOpenConnection();
[Fact(Skip = "Bug in Firebird; a PR to fix it has been submitted")] [Fact(Skip = "Bug in Firebird; a PR to fix it has been submitted")]
public void Issue178_Firebird() public void Issue178_Firebird()
{ {
const string cs = "initial catalog=localhost:database;user id=SYSDBA;password=masterkey"; using (var connection = GetOpenFirebirdConnection())
using (var connection = new FbConnection(cs))
{ {
connection.Open();
const string sql = "select count(*) from Issue178"; const string sql = "select count(*) from Issue178";
try { connection.Execute("drop table Issue178"); } try { connection.Execute("drop table Issue178"); }
catch { /* don't care */ } catch { /* don't care */ }
......
...@@ -8,7 +8,11 @@ ...@@ -8,7 +8,11 @@
namespace Dapper.Tests namespace Dapper.Tests
{ {
public class Linq2SqlTests : TestBase public sealed class SystemSqlClientLinq2SqlTests : Linq2SqlTests<SystemSqlClientProvider> { }
#if MSSQLCLIENT
public sealed class MicrosoftSqlClientLinq2SqlTests : Linq2SqlTests<MicrosoftSqlClientProvider> { }
#endif
public abstract class Linq2SqlTests<TProvider> : TestBase<TProvider> where TProvider : DatabaseProvider
{ {
[Fact] [Fact]
public void TestLinqBinaryToClass() public void TestLinqBinaryToClass()
......
using System; using System;
using System.Data;
using System.Data.Common;
using System.Linq; using System.Linq;
using System.Threading.Tasks;
using Xunit; using Xunit;
namespace Dapper.Tests namespace Dapper.Tests
{ {
public class MySQLTests : TestBase public sealed class MySqlProvider : DatabaseProvider
{ {
private static MySql.Data.MySqlClient.MySqlConnection GetMySqlConnection(bool open = true, public override DbProviderFactory Factory => MySql.Data.MySqlClient.MySqlClientFactory.Instance;
bool convertZeroDatetime = false, bool allowZeroDatetime = false) public override string GetConnectionString() => IsAppVeyor
{
string cs = IsAppVeyor
? "Server=localhost;Database=test;Uid=root;Pwd=Password12!;" ? "Server=localhost;Database=test;Uid=root;Pwd=Password12!;"
: "Server=localhost;Database=tests;Uid=test;Pwd=pass;"; : "Server=localhost;Database=tests;Uid=test;Pwd=pass;";
var csb = new MySql.Data.MySqlClient.MySqlConnectionStringBuilder(cs)
{ public DbConnection GetMySqlConnection(bool open = true,
AllowZeroDateTime = allowZeroDatetime, bool convertZeroDatetime = false, bool allowZeroDatetime = false)
ConvertZeroDateTime = convertZeroDatetime {
}; string cs = GetConnectionString();
var conn = new MySql.Data.MySqlClient.MySqlConnection(csb.ConnectionString); var csb = Factory.CreateConnectionStringBuilder();
csb.ConnectionString = cs;
((dynamic)csb).AllowZeroDateTime = allowZeroDatetime;
((dynamic)csb).ConvertZeroDateTime = convertZeroDatetime;
var conn = Factory.CreateConnection();
conn.ConnectionString = csb.ConnectionString;
if (open) conn.Open(); if (open) conn.Open();
return conn; return conn;
} }
}
public class MySQLTests : TestBase<MySqlProvider>
{
[FactMySql] [FactMySql]
public void DapperEnumValue_Mysql() public void DapperEnumValue_Mysql()
{ {
using (var conn = GetMySqlConnection()) using (var conn = Provider.GetMySqlConnection())
{ {
Common.DapperEnumValue(conn); Common.DapperEnumValue(conn);
} }
...@@ -34,7 +42,7 @@ public void DapperEnumValue_Mysql() ...@@ -34,7 +42,7 @@ public void DapperEnumValue_Mysql()
[FactMySql(Skip = "See https://github.com/StackExchange/Dapper/issues/552, not resolved on the MySQL end.")] [FactMySql(Skip = "See https://github.com/StackExchange/Dapper/issues/552, not resolved on the MySQL end.")]
public void Issue552_SignedUnsignedBooleans() public void Issue552_SignedUnsignedBooleans()
{ {
using (var conn = GetMySqlConnection(true, false, false)) using (var conn = Provider.GetMySqlConnection(true, false, false))
{ {
conn.Execute(@" conn.Execute(@"
CREATE TEMPORARY TABLE IF NOT EXISTS `bar` ( CREATE TEMPORARY TABLE IF NOT EXISTS `bar` (
...@@ -74,7 +82,7 @@ private class MySqlHasBool ...@@ -74,7 +82,7 @@ private class MySqlHasBool
[FactMySql] [FactMySql]
public void Issue295_NullableDateTime_MySql_Default() public void Issue295_NullableDateTime_MySql_Default()
{ {
using (var conn = GetMySqlConnection(true, false, false)) using (var conn = Provider.GetMySqlConnection(true, false, false))
{ {
Common.TestDateTime(conn); Common.TestDateTime(conn);
} }
...@@ -83,7 +91,7 @@ public void Issue295_NullableDateTime_MySql_Default() ...@@ -83,7 +91,7 @@ public void Issue295_NullableDateTime_MySql_Default()
[FactMySql] [FactMySql]
public void Issue295_NullableDateTime_MySql_ConvertZeroDatetime() public void Issue295_NullableDateTime_MySql_ConvertZeroDatetime()
{ {
using (var conn = GetMySqlConnection(true, true, false)) using (var conn = Provider.GetMySqlConnection(true, true, false))
{ {
Common.TestDateTime(conn); Common.TestDateTime(conn);
} }
...@@ -92,7 +100,7 @@ public void Issue295_NullableDateTime_MySql_ConvertZeroDatetime() ...@@ -92,7 +100,7 @@ public void Issue295_NullableDateTime_MySql_ConvertZeroDatetime()
[FactMySql(Skip = "See https://github.com/StackExchange/Dapper/issues/295, AllowZeroDateTime=True is not supported")] [FactMySql(Skip = "See https://github.com/StackExchange/Dapper/issues/295, AllowZeroDateTime=True is not supported")]
public void Issue295_NullableDateTime_MySql_AllowZeroDatetime() public void Issue295_NullableDateTime_MySql_AllowZeroDatetime()
{ {
using (var conn = GetMySqlConnection(true, false, true)) using (var conn = Provider.GetMySqlConnection(true, false, true))
{ {
Common.TestDateTime(conn); Common.TestDateTime(conn);
} }
...@@ -101,7 +109,7 @@ public void Issue295_NullableDateTime_MySql_AllowZeroDatetime() ...@@ -101,7 +109,7 @@ public void Issue295_NullableDateTime_MySql_AllowZeroDatetime()
[FactMySql(Skip = "See https://github.com/StackExchange/Dapper/issues/295, AllowZeroDateTime=True is not supported")] [FactMySql(Skip = "See https://github.com/StackExchange/Dapper/issues/295, AllowZeroDateTime=True is not supported")]
public void Issue295_NullableDateTime_MySql_ConvertAllowZeroDatetime() public void Issue295_NullableDateTime_MySql_ConvertAllowZeroDatetime()
{ {
using (var conn = GetMySqlConnection(true, true, true)) using (var conn = Provider.GetMySqlConnection(true, true, true))
{ {
Common.TestDateTime(conn); Common.TestDateTime(conn);
} }
...@@ -110,7 +118,7 @@ public void Issue295_NullableDateTime_MySql_ConvertAllowZeroDatetime() ...@@ -110,7 +118,7 @@ public void Issue295_NullableDateTime_MySql_ConvertAllowZeroDatetime()
[FactMySql] [FactMySql]
public void Issue426_SO34439033_DateTimeGainsTicks() public void Issue426_SO34439033_DateTimeGainsTicks()
{ {
using (var conn = GetMySqlConnection(true, true, true)) using (var conn = Provider.GetMySqlConnection(true, true, true))
{ {
try { conn.Execute("drop table Issue426_Test"); } catch { /* don't care */ } try { conn.Execute("drop table Issue426_Test"); } catch { /* don't care */ }
try { conn.Execute("create table Issue426_Test (Id int not null, Time time not null)"); } catch { /* don't care */ } try { conn.Execute("create table Issue426_Test (Id int not null, Time time not null)"); } catch { /* don't care */ }
...@@ -133,7 +141,7 @@ public void Issue426_SO34439033_DateTimeGainsTicks() ...@@ -133,7 +141,7 @@ public void Issue426_SO34439033_DateTimeGainsTicks()
[FactMySql] [FactMySql]
public void SO36303462_Tinyint_Bools() public void SO36303462_Tinyint_Bools()
{ {
using (var conn = GetMySqlConnection(true, true, true)) using (var conn = Provider.GetMySqlConnection(true, true, true))
{ {
try { conn.Execute("drop table SO36303462_Test"); } catch { /* don't care */ } try { conn.Execute("drop table SO36303462_Test"); } catch { /* don't care */ }
conn.Execute("create table SO36303462_Test (Id int not null, IsBold tinyint not null);"); conn.Execute("create table SO36303462_Test (Id int not null, IsBold tinyint not null);");
...@@ -149,6 +157,52 @@ public void SO36303462_Tinyint_Bools() ...@@ -149,6 +157,52 @@ public void SO36303462_Tinyint_Bools()
} }
} }
[FactMySql]
public void Issue1277_ReaderSync()
{
using (var conn = Provider.GetMySqlConnection())
{
try { conn.Execute("drop table Issue1277_Test"); } catch { /* don't care */ }
conn.Execute("create table Issue1277_Test (Id int not null, IsBold tinyint not null);");
conn.Execute("insert Issue1277_Test (Id, IsBold) values (1,1);");
conn.Execute("insert Issue1277_Test (Id, IsBold) values (2,0);");
conn.Execute("insert Issue1277_Test (Id, IsBold) values (3,1);");
using (var reader = conn.ExecuteReader(
"select * from Issue1277_Test where Id < @id",
new { id = 42 }))
{
var table = new DataTable();
table.Load(reader);
Assert.Equal(2, table.Columns.Count);
Assert.Equal(3, table.Rows.Count);
}
}
}
[FactMySql]
public async Task Issue1277_ReaderAsync()
{
using (var conn = Provider.GetMySqlConnection())
{
try { await conn.ExecuteAsync("drop table Issue1277_Test"); } catch { /* don't care */ }
await conn.ExecuteAsync("create table Issue1277_Test (Id int not null, IsBold tinyint not null);");
await conn.ExecuteAsync("insert Issue1277_Test (Id, IsBold) values (1,1);");
await conn.ExecuteAsync("insert Issue1277_Test (Id, IsBold) values (2,0);");
await conn.ExecuteAsync("insert Issue1277_Test (Id, IsBold) values (3,1);");
using (var reader = await conn.ExecuteReaderAsync(
"select * from Issue1277_Test where Id < @id",
new { id = 42 }))
{
var table = new DataTable();
table.Load(reader);
Assert.Equal(2, table.Columns.Count);
Assert.Equal(3, table.Rows.Count);
}
}
}
private class SO36303462 private class SO36303462
{ {
public int Id { get; set; } public int Id { get; set; }
...@@ -176,7 +230,7 @@ static FactMySqlAttribute() ...@@ -176,7 +230,7 @@ static FactMySqlAttribute()
{ {
try try
{ {
using (GetMySqlConnection(true)) { /* just trying to see if it works */ } using (DatabaseProvider<MySqlProvider>.Instance.GetMySqlConnection(true)) { /* just trying to see if it works */ }
} }
catch (Exception ex) catch (Exception ex)
{ {
......
#if OLEDB #if OLEDB
using System; using System;
using System.Data.Common;
using System.Data.OleDb; using System.Data.OleDb;
using System.Linq; using System.Linq;
using Xunit; using Xunit;
namespace Dapper.Tests namespace Dapper.Tests
{ {
public class OLDEBTests : TestBase public class OLEDBProvider : DatabaseProvider
{ {
public static string OleDbConnectionString => public override DbProviderFactory Factory => OleDbFactory.Instance;
public override string GetConnectionString() =>
IsAppVeyor IsAppVeyor
? @"Provider=SQLOLEDB;Data Source=(local)\SQL2016;Initial Catalog=tempdb;User Id=sa;Password=Password12!" ? @"Provider=SQLOLEDB;Data Source=(local)\SQL2016;Initial Catalog=tempdb;User Id=sa;Password=Password12!"
: "Provider=SQLOLEDB;Data Source=.;Initial Catalog=tempdb;Integrated Security=SSPI"; : "Provider=SQLOLEDB;Data Source=.;Initial Catalog=tempdb;Integrated Security=SSPI";
}
public OleDbConnection GetOleDbConnection() public class OLDEBTests : TestBase<OLEDBProvider>
{ {
var conn = new OleDbConnection(OleDbConnectionString); public OleDbConnection GetOleDbConnection() => (OleDbConnection) Provider.GetOpenConnection();
conn.Open();
return conn;
}
// see https://stackoverflow.com/q/18847510/23354 // see https://stackoverflow.com/q/18847510/23354
[Fact] [Fact]
......
using System; using System;
using System.Data; using System.Data;
using System.Data.Common;
using System.Linq; using System.Linq;
using Xunit; using Xunit;
namespace Dapper.Tests namespace Dapper.Tests
{ {
public class PostgresqlTests : TestBase public class PostgresProvider : DatabaseProvider
{ {
private static Npgsql.NpgsqlConnection GetOpenNpgsqlConnection() public override DbProviderFactory Factory => Npgsql.NpgsqlFactory.Instance;
{ public override string GetConnectionString() => IsAppVeyor
string cs = IsAppVeyor
? "Server=localhost;Port=5432;User Id=postgres;Password=Password12!;Database=test" ? "Server=localhost;Port=5432;User Id=postgres;Password=Password12!;Database=test"
: "Server=localhost;Port=5432;User Id=dappertest;Password=dapperpass;Database=dappertest"; // ;Encoding = UNICODE : "Server=localhost;Port=5432;User Id=dappertest;Password=dapperpass;Database=dappertest"; // ;Encoding = UNICODE
var conn = new Npgsql.NpgsqlConnection(cs); }
conn.Open(); public class PostgresqlTests : TestBase<PostgresProvider>
return conn; {
} private Npgsql.NpgsqlConnection GetOpenNpgsqlConnection() => (Npgsql.NpgsqlConnection)Provider.GetOpenConnection();
private class Cat private class Cat
{ {
...@@ -71,7 +71,7 @@ static FactPostgresqlAttribute() ...@@ -71,7 +71,7 @@ static FactPostgresqlAttribute()
{ {
try try
{ {
using (GetOpenNpgsqlConnection()) { /* just trying to see if it works */ } using (DatabaseProvider<PostgresProvider>.Instance.GetOpenConnection()) { /* just trying to see if it works */ }
} }
catch (Exception ex) catch (Exception ex)
{ {
......
using Microsoft.Data.Sqlite; using Microsoft.Data.Sqlite;
using System; using System;
using System.Data.Common;
using System.Linq; using System.Linq;
using System.Threading.Tasks; using System.Threading.Tasks;
using Xunit; using Xunit;
namespace Dapper.Tests namespace Dapper.Tests
{ {
public class SqliteTests : TestBase public class SqliteProvider : DatabaseProvider
{ {
protected static SqliteConnection GetSQLiteConnection(bool open = true) public override DbProviderFactory Factory => SqliteFactory.Instance;
public override string GetConnectionString() => "Data Source=:memory:";
}
public abstract class SqliteTypeTestBase : TestBase<SqliteProvider>
{
protected SqliteConnection GetSQLiteConnection(bool open = true)
=> (SqliteConnection)(open ? Provider.GetOpenConnection() : Provider.GetClosedConnection());
[AttributeUsage(AttributeTargets.Method, AllowMultiple = false)]
public class FactSqliteAttribute : FactAttribute
{ {
var connection = new SqliteConnection("Data Source=:memory:"); public override string Skip
if (open) connection.Open(); {
return connection; get { return unavailable ?? base.Skip; }
set { base.Skip = value; }
}
private static readonly string unavailable;
static FactSqliteAttribute()
{
try
{
using (DatabaseProvider<SqliteProvider>.Instance.GetOpenConnection())
{
}
}
catch (Exception ex)
{
unavailable = $"Sqlite is unavailable: {ex.Message}";
}
}
} }
}
[Collection(NonParallelDefinition.Name)]
public class SqliteTypeHandlerTests : SqliteTypeTestBase
{
[FactSqlite] [FactSqlite]
public void DapperEnumValue_Sqlite() public void Issue466_SqliteHatesOptimizations()
{ {
using (var connection = GetSQLiteConnection()) using (var connection = GetSQLiteConnection())
{ {
Common.DapperEnumValue(connection); SqlMapper.ResetTypeHandlers();
var row = connection.Query<HazNameId>("select 42 as Id").First();
Assert.Equal(42, row.Id);
row = connection.Query<HazNameId>("select 42 as Id").First();
Assert.Equal(42, row.Id);
SqlMapper.ResetTypeHandlers();
row = connection.QueryFirst<HazNameId>("select 42 as Id");
Assert.Equal(42, row.Id);
row = connection.QueryFirst<HazNameId>("select 42 as Id");
Assert.Equal(42, row.Id);
} }
} }
[Collection(NonParallelDefinition.Name)] [FactSqlite]
public class SqliteTypeHandlerTests : TestBase public async Task Issue466_SqliteHatesOptimizations_Async()
{ {
[FactSqlite] using (var connection = GetSQLiteConnection())
public void Issue466_SqliteHatesOptimizations()
{ {
using (var connection = GetSQLiteConnection()) SqlMapper.ResetTypeHandlers();
{ var row = (await connection.QueryAsync<HazNameId>("select 42 as Id").ConfigureAwait(false)).First();
SqlMapper.ResetTypeHandlers(); Assert.Equal(42, row.Id);
var row = connection.Query<HazNameId>("select 42 as Id").First(); row = (await connection.QueryAsync<HazNameId>("select 42 as Id").ConfigureAwait(false)).First();
Assert.Equal(42, row.Id); Assert.Equal(42, row.Id);
row = connection.Query<HazNameId>("select 42 as Id").First();
Assert.Equal(42, row.Id);
SqlMapper.ResetTypeHandlers(); SqlMapper.ResetTypeHandlers();
row = connection.QueryFirst<HazNameId>("select 42 as Id"); row = await connection.QueryFirstAsync<HazNameId>("select 42 as Id").ConfigureAwait(false);
Assert.Equal(42, row.Id); Assert.Equal(42, row.Id);
row = connection.QueryFirst<HazNameId>("select 42 as Id"); row = await connection.QueryFirstAsync<HazNameId>("select 42 as Id").ConfigureAwait(false);
Assert.Equal(42, row.Id); Assert.Equal(42, row.Id);
}
} }
}
}
[FactSqlite] public class SqliteTests : SqliteTypeTestBase
public async Task Issue466_SqliteHatesOptimizations_Async() {
[FactSqlite]
public void DapperEnumValue_Sqlite()
{
using (var connection = GetSQLiteConnection())
{ {
using (var connection = GetSQLiteConnection()) Common.DapperEnumValue(connection);
{
SqlMapper.ResetTypeHandlers();
var row = (await connection.QueryAsync<HazNameId>("select 42 as Id").ConfigureAwait(false)).First();
Assert.Equal(42, row.Id);
row = (await connection.QueryAsync<HazNameId>("select 42 as Id").ConfigureAwait(false)).First();
Assert.Equal(42, row.Id);
SqlMapper.ResetTypeHandlers();
row = await connection.QueryFirstAsync<HazNameId>("select 42 as Id").ConfigureAwait(false);
Assert.Equal(42, row.Id);
row = await connection.QueryFirstAsync<HazNameId>("select 42 as Id").ConfigureAwait(false);
Assert.Equal(42, row.Id);
}
} }
} }
[FactSqlite] [FactSqlite]
public void Isse467_SqliteLikesParametersWithPrefix() public void Isse467_SqliteLikesParametersWithPrefix()
{ {
...@@ -89,32 +123,6 @@ private void Isse467_SqliteParameterNaming(bool prefix) ...@@ -89,32 +123,6 @@ private void Isse467_SqliteParameterNaming(bool prefix)
var i = Convert.ToInt32(cmd.ExecuteScalar()); var i = Convert.ToInt32(cmd.ExecuteScalar());
Assert.Equal(42, i); Assert.Equal(42, i);
} }
} }
[AttributeUsage(AttributeTargets.Method, AllowMultiple = false)]
public class FactSqliteAttribute : FactAttribute
{
public override string Skip
{
get { return unavailable ?? base.Skip; }
set { base.Skip = value; }
}
private static readonly string unavailable;
static FactSqliteAttribute()
{
try
{
using (GetSQLiteConnection())
{
}
}
catch (Exception ex)
{
unavailable = $"Sqlite is unavailable: {ex.Message}";
}
}
}
} }
} }
\ No newline at end of file
...@@ -6,7 +6,11 @@ ...@@ -6,7 +6,11 @@
namespace Dapper.Tests namespace Dapper.Tests
{ {
public class QueryMultipleTests : TestBase public sealed class SystemSqlClientQueryMultipleTests : QueryMultipleTests<SystemSqlClientProvider> { }
#if MSSQLCLIENT
public sealed class MicrosoftSqlClientQueryMultipleTests : QueryMultipleTests<MicrosoftSqlClientProvider> { }
#endif
public abstract class QueryMultipleTests<TProvider> : TestBase<TProvider> where TProvider : DatabaseProvider
{ {
[Fact] [Fact]
public void TestQueryMultipleBuffered() public void TestQueryMultipleBuffered()
......
using System; using System;
using System.Data; using System.Data;
using System.Data.SqlClient;
using System.Globalization; using System.Globalization;
using Xunit; using Xunit;
using System.Data.Common;
#if !NETCOREAPP1_0 #if !NETCOREAPP1_0
using System.Threading; using System.Threading;
#endif #endif
namespace Dapper.Tests namespace Dapper.Tests
{ {
public abstract class TestBase : IDisposable public static class DatabaseProvider<TProvider> where TProvider : DatabaseProvider
{ {
protected static readonly bool IsAppVeyor = Environment.GetEnvironmentVariable("Appveyor")?.ToUpperInvariant() == "TRUE"; public static TProvider Instance { get; } = Activator.CreateInstance<TProvider>();
}
public static string ConnectionString => public abstract class DatabaseProvider
IsAppVeyor {
? @"Server=(local)\SQL2016;Database=tempdb;User ID=sa;Password=Password12!" public abstract DbProviderFactory Factory { get; }
: "Data Source=.;Initial Catalog=tempdb;Integrated Security=True";
protected SqlConnection _connection; public static bool IsAppVeyor { get; } = Environment.GetEnvironmentVariable("Appveyor")?.ToUpperInvariant() == "TRUE";
protected SqlConnection connection => _connection ?? (_connection = GetOpenConnection()); public virtual void Dispose() { }
public abstract string GetConnectionString();
public static SqlConnection GetOpenConnection(bool mars = false) public DbConnection GetOpenConnection()
{ {
var cs = ConnectionString; var conn = Factory.CreateConnection();
if (mars) conn.ConnectionString = GetConnectionString();
{ conn.Open();
var scsb = new SqlConnectionStringBuilder(cs) if (conn.State != ConnectionState.Open) throw new InvalidOperationException("should be open!");
{ return conn;
MultipleActiveResultSets = true
};
cs = scsb.ConnectionString;
}
var connection = new SqlConnection(cs);
connection.Open();
return connection;
} }
public SqlConnection GetClosedConnection() public DbConnection GetClosedConnection()
{ {
var conn = new SqlConnection(ConnectionString); var conn = Factory.CreateConnection();
conn.ConnectionString = GetConnectionString();
if (conn.State != ConnectionState.Closed) throw new InvalidOperationException("should be closed!"); if (conn.State != ConnectionState.Closed) throw new InvalidOperationException("should be closed!");
return conn; return conn;
} }
public DbParameter CreateRawParameter(string name, object value)
{
var p = Factory.CreateParameter();
p.ParameterName = name;
p.Value = value ?? DBNull.Value;
return p;
}
}
public abstract class SqlServerDatabaseProvider : DatabaseProvider
{
public override string GetConnectionString() =>
IsAppVeyor
? @"Server=(local)\SQL2016;Database=tempdb;User ID=sa;Password=Password12!"
: "Data Source=.;Initial Catalog=tempdb;Integrated Security=True";
public DbConnection GetOpenConnection(bool mars)
{
if (!mars) return GetOpenConnection();
var scsb = Factory.CreateConnectionStringBuilder();
scsb.ConnectionString = GetConnectionString();
((dynamic)scsb).MultipleActiveResultSets = true;
var conn = Factory.CreateConnection();
conn.ConnectionString = scsb.ConnectionString;
conn.Open();
if (conn.State != ConnectionState.Open) throw new InvalidOperationException("should be open!");
return conn;
}
}
public sealed class SystemSqlClientProvider : SqlServerDatabaseProvider
{
public override DbProviderFactory Factory => System.Data.SqlClient.SqlClientFactory.Instance;
}
#if MSSQLCLIENT
public sealed class MicrosoftSqlClientProvider : SqlServerDatabaseProvider
{
public override DbProviderFactory Factory => Microsoft.Data.SqlClient.SqlClientFactory.Instance;
}
#endif
public abstract class TestBase<TProvider> : IDisposable where TProvider : DatabaseProvider
{
protected DbConnection GetOpenConnection() => Provider.GetOpenConnection();
protected DbConnection GetClosedConnection() => Provider.GetClosedConnection();
protected DbConnection _connection;
protected DbConnection connection => _connection ?? (_connection = Provider.GetOpenConnection());
public TProvider Provider { get; } = DatabaseProvider<TProvider>.Instance;
protected static CultureInfo ActiveCulture protected static CultureInfo ActiveCulture
{ {
#if NETCOREAPP1_0 #if NETCOREAPP1_0
...@@ -58,7 +102,10 @@ protected static CultureInfo ActiveCulture ...@@ -58,7 +102,10 @@ protected static CultureInfo ActiveCulture
static TestBase() static TestBase()
{ {
Console.WriteLine("Dapper: " + typeof(SqlMapper).AssemblyQualifiedName); Console.WriteLine("Dapper: " + typeof(SqlMapper).AssemblyQualifiedName);
Console.WriteLine("Using Connectionstring: {0}", ConnectionString); var provider = DatabaseProvider<TProvider>.Instance;
Console.WriteLine("Using Connectionstring: {0}", provider.GetConnectionString());
var factory = provider.Factory;
Console.WriteLine("Using Provider: {0}", factory.GetType().FullName);
#if NETCOREAPP1_0 #if NETCOREAPP1_0
Console.WriteLine("CoreCLR (netcoreapp1.0)"); Console.WriteLine("CoreCLR (netcoreapp1.0)");
#else #else
...@@ -77,14 +124,15 @@ static TestBase() ...@@ -77,14 +124,15 @@ static TestBase()
#endif #endif
} }
public void Dispose() public virtual void Dispose()
{ {
_connection?.Dispose(); _connection?.Dispose();
_connection = null;
Provider?.Dispose();
} }
} }
[CollectionDefinition(Name, DisableParallelization = true)] public static class NonParallelDefinition
public class NonParallelDefinition : TestBase
{ {
public const string Name = "NonParallel"; public const string Name = "NonParallel";
} }
......
...@@ -7,7 +7,11 @@ ...@@ -7,7 +7,11 @@
namespace Dapper.Tests namespace Dapper.Tests
{ {
public class TransactionTests : TestBase public sealed class SystemSqlClientTransactionTests : TransactionTests<SystemSqlClientProvider> { }
#if MSSQLCLIENT
public sealed class MicrosoftSqlClientTransactionTests : TransactionTests<MicrosoftSqlClientProvider> { }
#endif
public abstract class TransactionTests<TProvider> : TestBase<TProvider> where TProvider : DatabaseProvider
{ {
[Fact] [Fact]
public void TestTransactionCommit() public void TestTransactionCommit()
......
...@@ -3,7 +3,11 @@ ...@@ -3,7 +3,11 @@
namespace Dapper.Tests namespace Dapper.Tests
{ {
public class TupleTests : TestBase public sealed class SystemSqlClientTupleTests : TupleTests<SystemSqlClientProvider> { }
#if MSSQLCLIENT
public sealed class MicrosoftSqlClientTupleTests : TupleTests<MicrosoftSqlClientProvider> { }
#endif
public abstract class TupleTests<TProvider> : TestBase<TProvider> where TProvider : DatabaseProvider
{ {
[Fact] [Fact]
public void TupleStructParameter_Fails_HelpfulMessage() public void TupleStructParameter_Fails_HelpfulMessage()
......
...@@ -9,7 +9,13 @@ ...@@ -9,7 +9,13 @@
namespace Dapper.Tests namespace Dapper.Tests
{ {
[Collection(NonParallelDefinition.Name)] [Collection(NonParallelDefinition.Name)]
public class TypeHandlerTests : TestBase public sealed class SystemSqlClientTypeHandlerTests : TypeHandlerTests<SystemSqlClientProvider> { }
#if MSSQLCLIENT
[Collection(NonParallelDefinition.Name)]
public sealed class MicrosoftSqlClientTypeHandlerTests : TypeHandlerTests<MicrosoftSqlClientProvider> { }
#endif
public abstract class TypeHandlerTests<TProvider> : TestBase<TProvider> where TProvider : DatabaseProvider
{ {
[Fact] [Fact]
public void TestChangingDefaultStringTypeMappingToAnsiString() public void TestChangingDefaultStringTypeMappingToAnsiString()
......
...@@ -4,7 +4,11 @@ ...@@ -4,7 +4,11 @@
namespace Dapper.Tests namespace Dapper.Tests
{ {
public class XmlTests : TestBase public sealed class SystemSqlClientXmlTests : XmlTests<SystemSqlClientProvider> { }
#if MSSQLCLIENT
public sealed class MicrosoftSqlClientXmlTests : XmlTests<MicrosoftSqlClientProvider> { }
#endif
public abstract class XmlTests<TProvider> : TestBase<TProvider> where TProvider : DatabaseProvider
{ {
[Fact] [Fact]
public void CommonXmlTypesSupported() public void CommonXmlTypesSupported()
......
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
<Title>Dapper</Title> <Title>Dapper</Title>
<Description>A high performance Micro-ORM supporting SQL Server, MySQL, Sqlite, SqlCE, Firebird etc..</Description> <Description>A high performance Micro-ORM supporting SQL Server, MySQL, Sqlite, SqlCE, Firebird etc..</Description>
<Authors>Sam Saffron;Marc Gravell;Nick Craver</Authors> <Authors>Sam Saffron;Marc Gravell;Nick Craver</Authors>
<TargetFrameworks>net451;netstandard1.3;netstandard2.0</TargetFrameworks> <TargetFrameworks>net451;netstandard1.3;netstandard2.0;netcoreapp2.1</TargetFrameworks>
</PropertyGroup> </PropertyGroup>
<ItemGroup Condition="'$(TargetFramework)' == 'net451'"> <ItemGroup Condition="'$(TargetFramework)' == 'net451'">
<Reference Include="System" /> <Reference Include="System" />
...@@ -15,9 +15,15 @@ ...@@ -15,9 +15,15 @@
<Reference Include="Microsoft.CSharp" /> <Reference Include="Microsoft.CSharp" />
</ItemGroup> </ItemGroup>
<ItemGroup Condition=" '$(TargetFramework)' == 'netstandard1.3' OR '$(TargetFramework)' == 'netstandard2.0'"> <ItemGroup Condition=" '$(TargetFramework)' == 'netstandard1.3' OR '$(TargetFramework)' == 'netstandard2.0'">
<PackageReference Include="System.Data.SqlClient" Version="4.6.0" />
<PackageReference Include="System.Reflection.Emit.Lightweight" Version="4.3.0" /> <PackageReference Include="System.Reflection.Emit.Lightweight" Version="4.3.0" />
<PackageReference Include="System.Reflection.TypeExtensions" Version="4.4.0" /> <PackageReference Include="System.Reflection.TypeExtensions" Version="4.4.0" />
<!-- it would be nice to use System.Data.Common here, but we need SqlClient for SqlDbType in 1.3, and legacy SqlDataRecord API-->
<!--<PackageReference Include="System.Data.Common" Version="4.3.0" />-->
<PackageReference Include="System.Data.SqlClient" Version="4.6.0" />
</ItemGroup>
<ItemGroup Condition="'$(TargetFramework)' == 'netcoreapp2.1'">
<PackageReference Include="System.Data.SqlClient" Version="4.6.0" />
</ItemGroup> </ItemGroup>
<ItemGroup Condition=" '$(TargetFramework)' == 'netstandard1.3' "> <ItemGroup Condition=" '$(TargetFramework)' == 'netstandard1.3' ">
<PackageReference Include="System.Collections.Concurrent" Version="4.3.0" /> <PackageReference Include="System.Collections.Concurrent" Version="4.3.0" />
......
...@@ -4,7 +4,10 @@ ...@@ -4,7 +4,10 @@
namespace Dapper namespace Dapper
{ {
internal sealed class SqlDataRecordHandler : SqlMapper.ITypeHandler internal sealed class SqlDataRecordHandler<T> : SqlMapper.ITypeHandler
#if !NETSTANDARD1_3
where T : IDataRecord
#endif
{ {
public object Parse(Type destinationType, object value) public object Parse(Type destinationType, object value)
{ {
...@@ -13,7 +16,7 @@ public object Parse(Type destinationType, object value) ...@@ -13,7 +16,7 @@ public object Parse(Type destinationType, object value)
public void SetValue(IDbDataParameter parameter, object value) public void SetValue(IDbDataParameter parameter, object value)
{ {
SqlDataRecordListTVPParameter.Set(parameter, value as IEnumerable<Microsoft.SqlServer.Server.SqlDataRecord>, null); SqlDataRecordListTVPParameter<T>.Set(parameter, value as IEnumerable<T>, null);
} }
} }
} }
using System.Collections.Generic; using System;
using System.Collections;
using System.Collections.Generic;
using System.Data; using System.Data;
using System.Linq; using System.Linq;
using System.Reflection;
using System.Reflection.Emit;
namespace Dapper namespace Dapper
{ {
/// <summary> /// <summary>
/// Used to pass a IEnumerable&lt;SqlDataRecord&gt; as a SqlDataRecordListTVPParameter /// Used to pass a IEnumerable&lt;SqlDataRecord&gt; as a SqlDataRecordListTVPParameter
/// </summary> /// </summary>
internal sealed class SqlDataRecordListTVPParameter : SqlMapper.ICustomQueryParameter internal sealed class SqlDataRecordListTVPParameter<T> : SqlMapper.ICustomQueryParameter
#if !NETSTANDARD1_3
where T : IDataRecord
#endif
{ {
private readonly IEnumerable<Microsoft.SqlServer.Server.SqlDataRecord> data; private readonly IEnumerable<T> data;
private readonly string typeName; private readonly string typeName;
/// <summary> /// <summary>
/// Create a new instance of <see cref="SqlDataRecordListTVPParameter"/>. /// Create a new instance of <see cref="SqlDataRecordListTVPParameter&lt;T&gt;"/>.
/// </summary> /// </summary>
/// <param name="data">The data records to convert into TVPs.</param> /// <param name="data">The data records to convert into TVPs.</param>
/// <param name="typeName">The parameter type name.</param> /// <param name="typeName">The parameter type name.</param>
public SqlDataRecordListTVPParameter(IEnumerable<Microsoft.SqlServer.Server.SqlDataRecord> data, string typeName) public SqlDataRecordListTVPParameter(IEnumerable<T> data, string typeName)
{ {
this.data = data; this.data = data;
this.typeName = typeName; this.typeName = typeName;
...@@ -30,14 +37,72 @@ void SqlMapper.ICustomQueryParameter.AddParameter(IDbCommand command, string nam ...@@ -30,14 +37,72 @@ void SqlMapper.ICustomQueryParameter.AddParameter(IDbCommand command, string nam
command.Parameters.Add(param); command.Parameters.Add(param);
} }
internal static void Set(IDbDataParameter parameter, IEnumerable<Microsoft.SqlServer.Server.SqlDataRecord> data, string typeName) internal static void Set(IDbDataParameter parameter, IEnumerable<T> data, string typeName)
{ {
parameter.Value = data != null && data.Any() ? data : null; parameter.Value = data != null && data.Any() ? data : null;
if (parameter is System.Data.SqlClient.SqlParameter sqlParam) StructuredHelper.ConfigureTVP(parameter, typeName);
}
}
static class StructuredHelper
{
private static readonly Hashtable s_udt = new Hashtable(), s_tvp = new Hashtable();
private static Action<IDbDataParameter, string> GetUDT(Type type)
=> (Action<IDbDataParameter, string>)s_udt[type] ?? SlowGetHelper(type, s_udt, "UdtTypeName", 29); // 29 = SqlDbType.Udt (avoiding ref)
private static Action<IDbDataParameter, string> GetTVP(Type type)
=> (Action<IDbDataParameter, string>)s_tvp[type] ?? SlowGetHelper(type, s_tvp, "TypeName", 30); // 30 = SqlDbType.Structured (avoiding ref)
static Action<IDbDataParameter, string> SlowGetHelper(Type type, Hashtable hashtable, string nameProperty, int sqlDbType)
{
lock (hashtable)
{ {
sqlParam.SqlDbType = SqlDbType.Structured; var helper = (Action<IDbDataParameter, string>)hashtable[type];
sqlParam.TypeName = typeName; if (helper == null)
{
helper = CreateFor(type, nameProperty, sqlDbType);
hashtable.Add(type, helper);
}
return helper;
} }
} }
static Action<IDbDataParameter, string> CreateFor(Type type, string nameProperty, int sqlDbType)
{
var name = type.GetProperty(nameProperty, BindingFlags.Public | BindingFlags.Instance);
if (name == null || !name.CanWrite)
{
return (p, n) => { };
}
var dbType = type.GetProperty("SqlDbType", BindingFlags.Public | BindingFlags.Instance);
if (dbType != null && !dbType.CanWrite) dbType = null;
var dm = new DynamicMethod(nameof(CreateFor) + "_" + type.Name, null,
new[] { typeof(IDbDataParameter), typeof(string) }, true);
var il = dm.GetILGenerator();
il.Emit(OpCodes.Ldarg_0);
il.Emit(OpCodes.Castclass, type);
il.Emit(OpCodes.Ldarg_1);
il.EmitCall(OpCodes.Callvirt, name.GetSetMethod(), null);
if (dbType != null)
{
il.Emit(OpCodes.Ldarg_0);
il.Emit(OpCodes.Castclass, type);
il.Emit(OpCodes.Ldc_I4, sqlDbType);
il.EmitCall(OpCodes.Callvirt, dbType.GetSetMethod(), null);
}
il.Emit(OpCodes.Ret);
return (Action<IDbDataParameter, string>)dm.CreateDelegate(typeof(Action<IDbDataParameter, string>));
}
// this needs to be done per-provider; "dynamic" doesn't work well on all runtimes, although that
// would be a fair option otherwise
internal static void ConfigureUDT(IDbDataParameter parameter, string typeName)
=> GetUDT(parameter.GetType())(parameter, typeName);
internal static void ConfigureTVP(IDbDataParameter parameter, string typeName)
=> GetTVP(parameter.GetType())(parameter, typeName);
} }
} }
...@@ -1101,7 +1101,7 @@ public static async Task<GridReader> QueryMultipleAsync(this IDbConnection cnn, ...@@ -1101,7 +1101,7 @@ public static async Task<GridReader> QueryMultipleAsync(this IDbConnection cnn,
/// </code> /// </code>
/// </example> /// </example>
public static Task<IDataReader> ExecuteReaderAsync(this IDbConnection cnn, string sql, object param = null, IDbTransaction transaction = null, int? commandTimeout = null, CommandType? commandType = null) => public static Task<IDataReader> ExecuteReaderAsync(this IDbConnection cnn, string sql, object param = null, IDbTransaction transaction = null, int? commandTimeout = null, CommandType? commandType = null) =>
ExecuteReaderImplAsync(cnn, new CommandDefinition(sql, param, transaction, commandTimeout, commandType, CommandFlags.Buffered), CommandBehavior.Default); ExecuteWrappedReaderImplAsync(cnn, new CommandDefinition(sql, param, transaction, commandTimeout, commandType, CommandFlags.Buffered), CommandBehavior.Default);
/// <summary> /// <summary>
/// Execute parameterized SQL and return an <see cref="IDataReader"/>. /// Execute parameterized SQL and return an <see cref="IDataReader"/>.
...@@ -1114,7 +1114,7 @@ public static async Task<GridReader> QueryMultipleAsync(this IDbConnection cnn, ...@@ -1114,7 +1114,7 @@ public static async Task<GridReader> QueryMultipleAsync(this IDbConnection cnn,
/// or <see cref="T:DataSet"/>. /// or <see cref="T:DataSet"/>.
/// </remarks> /// </remarks>
public static Task<IDataReader> ExecuteReaderAsync(this IDbConnection cnn, CommandDefinition command) => public static Task<IDataReader> ExecuteReaderAsync(this IDbConnection cnn, CommandDefinition command) =>
ExecuteReaderImplAsync(cnn, command, CommandBehavior.Default); ExecuteWrappedReaderImplAsync(cnn, command, CommandBehavior.Default);
/// <summary> /// <summary>
/// Execute parameterized SQL and return an <see cref="IDataReader"/>. /// Execute parameterized SQL and return an <see cref="IDataReader"/>.
...@@ -1128,26 +1128,27 @@ public static async Task<GridReader> QueryMultipleAsync(this IDbConnection cnn, ...@@ -1128,26 +1128,27 @@ public static async Task<GridReader> QueryMultipleAsync(this IDbConnection cnn,
/// or <see cref="T:DataSet"/>. /// or <see cref="T:DataSet"/>.
/// </remarks> /// </remarks>
public static Task<IDataReader> ExecuteReaderAsync(this IDbConnection cnn, CommandDefinition command, CommandBehavior commandBehavior) => public static Task<IDataReader> ExecuteReaderAsync(this IDbConnection cnn, CommandDefinition command, CommandBehavior commandBehavior) =>
ExecuteReaderImplAsync(cnn, command, commandBehavior); ExecuteWrappedReaderImplAsync(cnn, command, commandBehavior);
private static async Task<IDataReader> ExecuteReaderImplAsync(IDbConnection cnn, CommandDefinition command, CommandBehavior commandBehavior) private static async Task<IDataReader> ExecuteWrappedReaderImplAsync(IDbConnection cnn, CommandDefinition command, CommandBehavior commandBehavior)
{ {
Action<IDbCommand, object> paramReader = GetParameterReader(cnn, ref command); Action<IDbCommand, object> paramReader = GetParameterReader(cnn, ref command);
DbCommand cmd = null; DbCommand cmd = null;
bool wasClosed = cnn.State == ConnectionState.Closed; bool wasClosed = cnn.State == ConnectionState.Closed, disposeCommand = true;
try try
{ {
cmd = command.TrySetupAsyncCommand(cnn, paramReader); cmd = command.TrySetupAsyncCommand(cnn, paramReader);
if (wasClosed) await cnn.TryOpenAsync(command.CancellationToken).ConfigureAwait(false); if (wasClosed) await cnn.TryOpenAsync(command.CancellationToken).ConfigureAwait(false);
var reader = await ExecuteReaderWithFlagsFallbackAsync(cmd, wasClosed, commandBehavior, command.CancellationToken).ConfigureAwait(false); var reader = await ExecuteReaderWithFlagsFallbackAsync(cmd, wasClosed, commandBehavior, command.CancellationToken).ConfigureAwait(false);
wasClosed = false; wasClosed = false;
return reader; disposeCommand = false;
return WrappedReader.Create(cmd, reader);
} }
finally finally
{ {
if (wasClosed) cnn.Close(); if (wasClosed) cnn.Close();
cmd?.Dispose(); if (cmd != null && disposeCommand) cmd.Dispose();
} }
} }
......
...@@ -237,7 +237,7 @@ private static void ResetTypeHandlers(bool clone) ...@@ -237,7 +237,7 @@ private static void ResetTypeHandlers(bool clone)
[MethodImpl(MethodImplOptions.NoInlining)] [MethodImpl(MethodImplOptions.NoInlining)]
private static void AddSqlDataRecordsTypeHandler(bool clone) private static void AddSqlDataRecordsTypeHandler(bool clone)
{ {
AddTypeHandlerImpl(typeof(IEnumerable<Microsoft.SqlServer.Server.SqlDataRecord>), new SqlDataRecordHandler(), clone); AddTypeHandlerImpl(typeof(IEnumerable<Microsoft.SqlServer.Server.SqlDataRecord>), new SqlDataRecordHandler<Microsoft.SqlServer.Server.SqlDataRecord>(), clone);
} }
/// <summary> /// <summary>
...@@ -599,7 +599,7 @@ public static IDataReader ExecuteReader(this IDbConnection cnn, string sql, obje ...@@ -599,7 +599,7 @@ public static IDataReader ExecuteReader(this IDbConnection cnn, string sql, obje
{ {
var command = new CommandDefinition(sql, param, transaction, commandTimeout, commandType, CommandFlags.Buffered); var command = new CommandDefinition(sql, param, transaction, commandTimeout, commandType, CommandFlags.Buffered);
var reader = ExecuteReaderImpl(cnn, ref command, CommandBehavior.Default, out IDbCommand dbcmd); var reader = ExecuteReaderImpl(cnn, ref command, CommandBehavior.Default, out IDbCommand dbcmd);
return new WrappedReader(dbcmd, reader); return WrappedReader.Create(dbcmd, reader);
} }
/// <summary> /// <summary>
...@@ -615,7 +615,7 @@ public static IDataReader ExecuteReader(this IDbConnection cnn, string sql, obje ...@@ -615,7 +615,7 @@ public static IDataReader ExecuteReader(this IDbConnection cnn, string sql, obje
public static IDataReader ExecuteReader(this IDbConnection cnn, CommandDefinition command) public static IDataReader ExecuteReader(this IDbConnection cnn, CommandDefinition command)
{ {
var reader = ExecuteReaderImpl(cnn, ref command, CommandBehavior.Default, out IDbCommand dbcmd); var reader = ExecuteReaderImpl(cnn, ref command, CommandBehavior.Default, out IDbCommand dbcmd);
return new WrappedReader(dbcmd, reader); return WrappedReader.Create(dbcmd, reader);
} }
/// <summary> /// <summary>
...@@ -632,7 +632,7 @@ public static IDataReader ExecuteReader(this IDbConnection cnn, CommandDefinitio ...@@ -632,7 +632,7 @@ public static IDataReader ExecuteReader(this IDbConnection cnn, CommandDefinitio
public static IDataReader ExecuteReader(this IDbConnection cnn, CommandDefinition command, CommandBehavior commandBehavior) public static IDataReader ExecuteReader(this IDbConnection cnn, CommandDefinition command, CommandBehavior commandBehavior)
{ {
var reader = ExecuteReaderImpl(cnn, ref command, commandBehavior, out IDbCommand dbcmd); var reader = ExecuteReaderImpl(cnn, ref command, commandBehavior, out IDbCommand dbcmd);
return new WrappedReader(dbcmd, reader); return WrappedReader.Create(dbcmd, reader);
} }
/// <summary> /// <summary>
...@@ -3785,13 +3785,24 @@ public static void SetTypeName(this DataTable table, string typeName) ...@@ -3785,13 +3785,24 @@ public static void SetTypeName(this DataTable table, string typeName)
table?.ExtendedProperties[DataTableTypeNameKey] as string; table?.ExtendedProperties[DataTableTypeNameKey] as string;
#endif #endif
/// <summary>
/// Used to pass a IEnumerable&lt;SqlDataRecord&gt; as a TableValuedParameter.
/// </summary>
/// <param name="list">The list of records to convert to TVPs.</param>
/// <param name="typeName">The sql parameter type name.</param>
public static ICustomQueryParameter AsTableValuedParameter<T>(this IEnumerable<T> list, string typeName = null) where T : IDataRecord =>
new SqlDataRecordListTVPParameter<T>(list, typeName);
/*
/// <summary> /// <summary>
/// Used to pass a IEnumerable&lt;SqlDataRecord&gt; as a TableValuedParameter. /// Used to pass a IEnumerable&lt;SqlDataRecord&gt; as a TableValuedParameter.
/// </summary> /// </summary>
/// <param name="list">The list of records to convert to TVPs.</param> /// <param name="list">The list of records to convert to TVPs.</param>
/// <param name="typeName">The sql parameter type name.</param> /// <param name="typeName">The sql parameter type name.</param>
public static ICustomQueryParameter AsTableValuedParameter(this IEnumerable<Microsoft.SqlServer.Server.SqlDataRecord> list, string typeName = null) => public static ICustomQueryParameter AsTableValuedParameter(this IEnumerable<Microsoft.SqlServer.Server.SqlDataRecord> list, string typeName = null) =>
new SqlDataRecordListTVPParameter(list, typeName); new SqlDataRecordListTVPParameter<Microsoft.SqlServer.Server.SqlDataRecord>(list, typeName);
// ^^^ retained to avoid missing-method-exception; can presumably drop in a "major"
*/
// one per thread // one per thread
[ThreadStatic] [ThreadStatic]
......
using System; using System.Data;
using System.Data;
using System.Reflection;
#if !NETSTANDARD1_3 #if !NETSTANDARD1_3
namespace Dapper namespace Dapper
...@@ -30,17 +28,6 @@ public TableValuedParameter(DataTable table, string typeName) ...@@ -30,17 +28,6 @@ public TableValuedParameter(DataTable table, string typeName)
this.typeName = typeName; this.typeName = typeName;
} }
private static readonly Action<System.Data.SqlClient.SqlParameter, string> setTypeName;
static TableValuedParameter()
{
var prop = typeof(System.Data.SqlClient.SqlParameter).GetProperty("TypeName", BindingFlags.Instance | BindingFlags.Public);
if (prop != null && prop.PropertyType == typeof(string) && prop.CanWrite)
{
setTypeName = (Action<System.Data.SqlClient.SqlParameter, string>)
Delegate.CreateDelegate(typeof(Action<System.Data.SqlClient.SqlParameter, string>), prop.GetSetMethod());
}
}
void SqlMapper.ICustomQueryParameter.AddParameter(IDbCommand command, string name) void SqlMapper.ICustomQueryParameter.AddParameter(IDbCommand command, string name)
{ {
var param = command.CreateParameter(); var param = command.CreateParameter();
...@@ -58,12 +45,8 @@ internal static void Set(IDbDataParameter parameter, DataTable table, string typ ...@@ -58,12 +45,8 @@ internal static void Set(IDbDataParameter parameter, DataTable table, string typ
{ {
typeName = table.GetTypeName(); typeName = table.GetTypeName();
} }
if (!string.IsNullOrEmpty(typeName) && (parameter is System.Data.SqlClient.SqlParameter sqlParam)) if (!string.IsNullOrEmpty(typeName)) StructuredHelper.ConfigureTVP(parameter, typeName);
{
setTypeName?.Invoke(sqlParam, typeName);
sqlParam.SqlDbType = SqlDbType.Structured;
}
} }
} }
} }
#endif #endif
\ No newline at end of file
...@@ -33,11 +33,7 @@ void ITypeHandler.SetValue(IDbDataParameter parameter, object value) ...@@ -33,11 +33,7 @@ void ITypeHandler.SetValue(IDbDataParameter parameter, object value)
#pragma warning disable 0618 #pragma warning disable 0618
parameter.Value = SanitizeParameterValue(value); parameter.Value = SanitizeParameterValue(value);
#pragma warning restore 0618 #pragma warning restore 0618
if (parameter is System.Data.SqlClient.SqlParameter && !(value is DBNull)) if(!(value is DBNull)) StructuredHelper.ConfigureUDT(parameter, udtTypeName);
{
((System.Data.SqlClient.SqlParameter)parameter).SqlDbType = SqlDbType.Udt;
((System.Data.SqlClient.SqlParameter)parameter).UdtTypeName = udtTypeName;
}
} }
} }
#endif #endif
......
This diff is collapsed.
{
"sdk": {
"version": "2.2.203"
}
}
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment