Commit 67fc0d80 authored by mgravell's avatar mgravell

fix tests for TVPs with SqlDataRecord

parent 12312d2a
......@@ -43,7 +43,19 @@ void SqlMapper.IDynamicParameters.AddParameters(IDbCommand command, SqlMapper.Id
}
}
private static List<Microsoft.SqlServer.Server.SqlDataRecord> CreateSqlDataRecordList(IEnumerable<int> numbers)
private static IEnumerable<IDataRecord> CreateSqlDataRecordList(IDbCommand command, IEnumerable<int> numbers)
{
if (command is System.Data.SqlClient.SqlCommand) return CreateSqlDataRecordList_SD(numbers);
if (command is Microsoft.Data.SqlClient.SqlCommand) return CreateSqlDataRecordList_MD(numbers);
throw new ArgumentException(nameof(command));
}
private static IEnumerable<IDataRecord> CreateSqlDataRecordList(IDbConnection connection, IEnumerable<int> numbers)
{
if (connection is System.Data.SqlClient.SqlConnection) return CreateSqlDataRecordList_SD(numbers);
if (connection is Microsoft.Data.SqlClient.SqlConnection) return CreateSqlDataRecordList_MD(numbers);
throw new ArgumentException(nameof(connection));
}
private static List<Microsoft.SqlServer.Server.SqlDataRecord> CreateSqlDataRecordList_SD(IEnumerable<int> numbers)
{
var number_list = new List<Microsoft.SqlServer.Server.SqlDataRecord>();
......@@ -60,7 +72,25 @@ private static List<Microsoft.SqlServer.Server.SqlDataRecord> CreateSqlDataRecor
return number_list;
}
private static List<Microsoft.Data.SqlClient.Server.SqlDataRecord> CreateSqlDataRecordList_MD(IEnumerable<int> numbers)
{
var number_list = new List<Microsoft.Data.SqlClient.Server.SqlDataRecord>();
// Create an SqlMetaData object that describes our table type.
Microsoft.Data.SqlClient.Server.SqlMetaData[] tvp_definition = { new Microsoft.Data.SqlClient.Server.SqlMetaData("n", SqlDbType.Int) };
foreach (int n in numbers)
{
// Create a new record, using the metadata array above.
var rec = new Microsoft.Data.SqlClient.Server.SqlDataRecord(tvp_definition);
rec.SetInt32(0, n); // Set the value.
number_list.Add(rec); // Add it to the list.
}
return number_list;
}
private class IntDynamicParam : SqlMapper.IDynamicParameters
{
......@@ -72,16 +102,11 @@ public IntDynamicParam(IEnumerable<int> numbers)
public void AddParameters(IDbCommand command, SqlMapper.Identity identity)
{
var sqlCommand = (System.Data.SqlClient.SqlCommand)command;
sqlCommand.CommandType = CommandType.StoredProcedure;
command.CommandType = CommandType.StoredProcedure;
var number_list = CreateSqlDataRecordList(numbers);
var number_list = CreateSqlDataRecordList(command, numbers);
// Add the table parameter.
var p = sqlCommand.Parameters.Add("ints", SqlDbType.Structured);
p.Direction = ParameterDirection.Input;
p.TypeName = "int_list_type";
p.Value = number_list;
AddStructured(command, number_list);
}
}
......@@ -95,17 +120,35 @@ public IntCustomParam(IEnumerable<int> numbers)
public void AddParameter(IDbCommand command, string name)
{
var sqlCommand = (System.Data.SqlClient.SqlCommand)command;
sqlCommand.CommandType = CommandType.StoredProcedure;
command.CommandType = CommandType.StoredProcedure;
var number_list = CreateSqlDataRecordList(numbers);
var number_list = CreateSqlDataRecordList(command, numbers);
// Add the table parameter.
var p = sqlCommand.Parameters.Add(name, SqlDbType.Structured);
AddStructured(command, number_list);
}
}
private static IDbDataParameter AddStructured(IDbCommand command, object value)
{
if (command is System.Data.SqlClient.SqlCommand sdcmd)
{
var p = sdcmd.Parameters.Add("integers", SqlDbType.Structured);
p.Direction = ParameterDirection.Input;
p.TypeName = "int_list_type";
p.Value = number_list;
p.Value = value;
return p;
}
else if (command is Microsoft.Data.SqlClient.SqlCommand mdcmd)
{
var p = mdcmd.Parameters.Add("integers", SqlDbType.Structured);
p.Direction = ParameterDirection.Input;
p.TypeName = "int_list_type";
p.Value = value;
return p;
}
else
throw new ArgumentException(nameof(command));
}
/* TODO:
......@@ -319,16 +362,12 @@ public new void AddParameters(IDbCommand command, SqlMapper.Identity identity)
{
base.AddParameters(command, identity);
var sqlCommand = (System.Data.SqlClient.SqlCommand)command;
sqlCommand.CommandType = CommandType.StoredProcedure;
command.CommandType = CommandType.StoredProcedure;
var number_list = CreateSqlDataRecordList(numbers);
var number_list = CreateSqlDataRecordList(command, numbers);
// Add the table parameter.
var p = sqlCommand.Parameters.Add("ints", SqlDbType.Structured);
p.Direction = ParameterDirection.Input;
p.TypeName = "int_list_type";
p.Value = number_list;
AddStructured(command, number_list);
}
}
......@@ -374,7 +413,7 @@ public void TestSqlDataRecordListParametersWithAsTableValuedParameter()
connection.Execute("CREATE TYPE int_list_type AS TABLE (n int NOT NULL PRIMARY KEY)");
connection.Execute("CREATE PROC get_ints @integers int_list_type READONLY AS select * from @integers");
var records = CreateSqlDataRecordList(new int[] { 1, 2, 3 });
var records = CreateSqlDataRecordList(connection, new int[] { 1, 2, 3 });
var nums = connection.Query<int>("get_ints", new { integers = records.AsTableValuedParameter() }, commandType: CommandType.StoredProcedure).ToList();
Assert.Equal(new int[] { 1, 2, 3 }, nums);
......@@ -414,7 +453,7 @@ public void TestEmptySqlDataRecordListParametersWithAsTableValuedParameter()
connection.Execute("CREATE PROC get_ints @integers int_list_type READONLY AS select * from @integers");
var emptyRecord = CreateSqlDataRecordList(Enumerable.Empty<int>());
var emptyRecord = CreateSqlDataRecordList(connection, Enumerable.Empty<int>());
var nums = connection.Query<int>("get_ints", new { integers = emptyRecord.AsTableValuedParameter() }, commandType: CommandType.StoredProcedure).ToList();
Assert.True(nums.Count == 0);
......@@ -441,14 +480,28 @@ public void TestSqlDataRecordListParametersWithTypeHandlers()
connection.Execute("CREATE PROC get_ints @integers int_list_type READONLY AS select * from @integers");
// Variable type has to be IEnumerable<SqlDataRecord> for TypeHandler to kick in.
IEnumerable<Microsoft.SqlServer.Server.SqlDataRecord> records = CreateSqlDataRecordList(new int[] { 1, 2, 3 });
object args;
if (connection is System.Data.SqlClient.SqlConnection)
{
IEnumerable<Microsoft.SqlServer.Server.SqlDataRecord> records = CreateSqlDataRecordList_SD(new int[] { 1, 2, 3 });
args = new { integers = records };
}
else if (connection is Microsoft.Data.SqlClient.SqlConnection)
{
IEnumerable<Microsoft.Data.SqlClient.Server.SqlDataRecord> records = CreateSqlDataRecordList_MD(new int[] { 1, 2, 3 });
args = new { integers = records };
}
else
{
throw new ArgumentException(nameof(connection));
}
var nums = connection.Query<int>("get_ints", new { integers = records }, commandType: CommandType.StoredProcedure).ToList();
var nums = connection.Query<int>("get_ints", args, commandType: CommandType.StoredProcedure).ToList();
Assert.Equal(new int[] { 1, 2, 3 }, nums);
try
{
connection.Query<int>("select * from @integers", new { integers = records }).First();
connection.Query<int>("select * from @integers", args).First();
throw new InvalidOperationException();
}
catch (Exception ex)
......
......@@ -3795,7 +3795,6 @@ public static void SetTypeName(this DataTable table, string typeName)
public static ICustomQueryParameter AsTableValuedParameter<T>(this IEnumerable<T> list, string typeName = null) where T : IDataRecord =>
new SqlDataRecordListTVPParameter<T>(list, typeName);
/*
/// <summary>
/// Used to pass a IEnumerable&lt;SqlDataRecord&gt; as a TableValuedParameter.
/// </summary>
......@@ -3804,7 +3803,6 @@ public static void SetTypeName(this DataTable table, string typeName)
public static ICustomQueryParameter AsTableValuedParameter(this IEnumerable<Microsoft.SqlServer.Server.SqlDataRecord> list, string typeName = null) =>
new SqlDataRecordListTVPParameter<Microsoft.SqlServer.Server.SqlDataRecord>(list, typeName);
// ^^^ retained to avoid missing-method-exception; can presumably drop in a "major"
*/
// one per thread
[ThreadStatic]
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment