Commit 8d0eb4de authored by simon.cropp's avatar simon.cropp

avoid some casting by more use of generics

parent 18c01151
...@@ -400,22 +400,22 @@ private static CacheInfo GetCacheInfo(object param, Identity identity) ...@@ -400,22 +400,22 @@ private static CacheInfo GetCacheInfo(object param, Identity identity)
// dynamic is passed in as Object ... by c# design // dynamic is passed in as Object ... by c# design
if (typeof(T) == typeof(object) || typeof(T) == typeof(ExpandoObject)) if (typeof(T) == typeof(object) || typeof(T) == typeof(ExpandoObject))
{ {
oDeserializer = GetDynamicDeserializer(reader,startBound, length, returnNullIfFirstMissing); return GetDynamicDeserializer<T>(reader,startBound, length, returnNullIfFirstMissing);
} }
else if (typeof(T).IsClass && typeof(T) != typeof(string)) else if (typeof(T).IsClass && typeof(T) != typeof(string))
{ {
oDeserializer = GetClassDeserializer<T>(reader, startBound, length, returnNullIfFirstMissing: returnNullIfFirstMissing); return GetClassDeserializer<T>(reader, startBound, length, returnNullIfFirstMissing: returnNullIfFirstMissing);
} }
else else
{ {
oDeserializer = GetStructDeserializer<T>(reader); return GetStructDeserializer<T>(reader);
} }
var deserializer = (Func<IDataReader, T>)oDeserializer; var deserializer = (Func<IDataReader, T>)oDeserializer;
return deserializer; return deserializer;
} }
private static object GetDynamicDeserializer(IDataRecord reader, int startBound = 0, int length = -1, bool returnNullIfFirstMissing = false) private static Func<IDataReader, T> GetDynamicDeserializer<T>(IDataRecord reader, int startBound = 0, int length = -1, bool returnNullIfFirstMissing = false)
{ {
var colNames = new List<string>(); var colNames = new List<string>();
...@@ -429,7 +429,7 @@ private static object GetDynamicDeserializer(IDataRecord reader, int startBound ...@@ -429,7 +429,7 @@ private static object GetDynamicDeserializer(IDataRecord reader, int startBound
colNames.Add(reader.GetName(i)); colNames.Add(reader.GetName(i));
} }
Func<IDataReader, ExpandoObject> rval = return
r => r =>
{ {
IDictionary<string, object> row = new ExpandoObject(); IDictionary<string, object> row = new ExpandoObject();
...@@ -442,15 +442,14 @@ private static object GetDynamicDeserializer(IDataRecord reader, int startBound ...@@ -442,15 +442,14 @@ private static object GetDynamicDeserializer(IDataRecord reader, int startBound
row[colName] = tmp; row[colName] = tmp;
if (returnNullIfFirstMissing && first && tmp == null) if (returnNullIfFirstMissing && first && tmp == null)
{ {
return null; return default(T);
} }
i++; i++;
first = false; first = false;
} }
return (ExpandoObject)row; return (T)row;
}; };
return rval;
} }
[Browsable(false), EditorBrowsable(EditorBrowsableState.Never)] [Browsable(false), EditorBrowsable(EditorBrowsableState.Never)]
[Obsolete("This method is for internal usage only", true)] [Obsolete("This method is for internal usage only", true)]
...@@ -648,11 +647,9 @@ private static int ExecuteCommand(IDbConnection cnn, IDbTransaction tranaction, ...@@ -648,11 +647,9 @@ private static int ExecuteCommand(IDbConnection cnn, IDbTransaction tranaction,
} }
} }
private static object GetStructDeserializer<T>(IDataReader reader) private static Func<IDataReader, T> GetStructDeserializer<T>(IDataReader reader)
{ {
Func<IDataReader, T> deserializer = null; return r =>
deserializer = r =>
{ {
var val = r.GetValue(0); var val = r.GetValue(0);
if (val == DBNull.Value) if (val == DBNull.Value)
...@@ -661,7 +658,6 @@ private static object GetStructDeserializer<T>(IDataReader reader) ...@@ -661,7 +658,6 @@ private static object GetStructDeserializer<T>(IDataReader reader)
} }
return (T)val; return (T)val;
}; };
return deserializer;
} }
public static Func<IDataReader, T> GetClassDeserializer<T>(IDataReader reader, int startBound = 0, int length = -1, bool returnNullIfFirstMissing = false) public static Func<IDataReader, T> GetClassDeserializer<T>(IDataReader reader, int startBound = 0, int length = -1, bool returnNullIfFirstMissing = false)
...@@ -796,7 +792,7 @@ public IEnumerable<T> Read<T>() ...@@ -796,7 +792,7 @@ public IEnumerable<T> Read<T>()
if (reader == null) throw new ObjectDisposedException(GetType().Name); if (reader == null) throw new ObjectDisposedException(GetType().Name);
if (consumed) throw new InvalidOperationException("Each grid can only be iterated once"); if (consumed) throw new InvalidOperationException("Each grid can only be iterated once");
var identity = new Identity(sql, connection, typeof(T), null); var identity = new Identity(sql, connection, typeof(T), null);
var deserializer = SqlMapper.GetDeserializer<T>(identity, reader); var deserializer = GetDeserializer<T>(identity, reader);
consumed = true; consumed = true;
return ReadDeferred(gridIndex, deserializer); return ReadDeferred(gridIndex, deserializer);
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment