这两天利用空闲时间研究了一下 LINQ 的查询提供程序(Provider),并做了一个简单的实现。代码写得比较乱,没有注释,也没怎么对代码进行设计,因此有很多临时变量和一些不必要的操作,但重点在于说明实现原理。微软的 LINQ to SQL 实现非常复杂,这个例子只简单地实现了 Select 和 Where,其他操作都没有实现;并且对于 Where 查询,只支持有限的 ==、>、< 比较。不过这并不重要,如有需要可以添加相应的实现。
先把代码记录下来,以后有时间再优化代码并添加一些注释。
IQueryable的实现:
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Linq.Expressions;
using System.Collections;
namespace SimpleLinq2Sql
{
/// <summary>
/// Queryable "table" of <typeparamref name="T"/>. It only carries an expression tree and a
/// provider; translation and execution are delegated to <see cref="Provider"/>.
/// </summary>
/// <typeparam name="T">Element type produced when the query is enumerated.</typeparam>
public class CustomTable<T> : IQueryable<T>
{
    private readonly Type _ElementType;
    private readonly Expression _Expression;
    private readonly IQueryProvider _Provider;

    /// <summary>Always <c>typeof(T)</c>.</summary>
    public Type ElementType
    {
        get { return _ElementType; }
    }

    /// <summary>Expression tree describing this query (a constant for the root table).</summary>
    public Expression Expression
    {
        get { return _Expression; }
    }

    /// <summary>Provider that translates and executes <see cref="Expression"/>.</summary>
    public IQueryProvider Provider
    {
        get { return _Provider; }
    }

    /// <summary>Creates a queryable over <paramref name="expression"/> bound to <paramref name="provider"/>.</summary>
    /// <param name="expression">
    /// Expression this queryable represents. When null, a constant referring to this instance is
    /// used, i.e. the object acts as the root table of a query.
    /// </param>
    /// <param name="provider">Provider used to translate and execute the query; must not be null.</param>
    /// <exception cref="ArgumentNullException"><paramref name="provider"/> is null.</exception>
    public CustomTable(Expression expression, IQueryProvider provider)
    {
        if (provider == null)
            throw new ArgumentNullException("provider", "provider can't be null");
        _ElementType = typeof(T);
        // A null expression previously left Expression null, so enumeration had nothing to execute.
        _Expression = expression ?? Expression.Constant(this);
        _Provider = provider;
    }

    /// <summary>Creates a root table backed by a new <c>CustomProvider</c>.</summary>
    public CustomTable()
        : this(null, new CustomProvider())
    {
    }

    /// <summary>Executes the query through the provider and enumerates the materialized results.</summary>
    public IEnumerator<T> GetEnumerator()
    {
        return Provider.Execute<IEnumerable<T>>(Expression).GetEnumerator();
    }

    IEnumerator IEnumerable.GetEnumerator()
    {
        return GetEnumerator();
    }

    /// <summary>
    /// Returns the provider's <c>ToString()</c>. With <c>CustomProvider</c> that is the SQL it
    /// translated most recently. NOTE(review): when several queries share one provider, that may
    /// not be this query's SQL.
    /// </summary>
    public override string ToString()
    {
        return _Provider.ToString();
    }
}
}
IQueryProvider的实现:
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Linq.Expressions;
using System.Reflection;
using System.Data;
using System.Data.SqlClient;
namespace SimpleLinq2Sql
{
/// <summary>
/// Minimal <see cref="IQueryProvider"/> that translates chains of <c>Queryable.Where</c> and
/// <c>Queryable.Select</c> calls into nested SQL Server queries and materializes the rows through
/// <c>Table2Entity</c>.
/// </summary>
/// <remarks>
/// Each operator wraps the SQL of the previous one as a derived table aliased t0, t1, ...
/// Supported subset:
/// <list type="bullet">
/// <item><c>Where</c>: a single comparison (==, !=, &lt;, &lt;=, &gt;, &gt;=) between a member of the
/// element and a parameter-free value (constant or captured variable).</item>
/// <item><c>Select</c>: the identity, or an anonymous-type projection of plain members.</item>
/// </list>
/// Anything else throws <see cref="NotSupportedException"/>.
/// Values are inlined as escaped SQL literals so that <see cref="ToString"/> shows readable SQL.
/// Translation state lives in mutable fields without synchronization, so an instance is not
/// thread-safe.
/// </remarks>
public class CustomProvider : IQueryProvider
{
    // Connection string used by the parameterless constructor (the sample's original hard-coded value).
    private const string DefaultConnectionString = "Data Source=.;Initial Catalog=TestLinq;Integrated Security=True";

    private readonly string _connectionString;
    private string sql = "";
    private int count = 0;
    private string tableName = "";
    private string selector = "";
    private string where = "";
    private Type _PreType = null;
    private Type _ElementType = null;

    /// <summary>Creates a provider that connects with the sample's default connection string.</summary>
    public CustomProvider()
        : this(DefaultConnectionString)
    {
    }

    /// <summary>Creates a provider that runs its queries against <paramref name="connectionString"/>.</summary>
    /// <exception cref="ArgumentNullException"><paramref name="connectionString"/> is null.</exception>
    public CustomProvider(string connectionString)
    {
        if (connectionString == null)
            throw new ArgumentNullException("connectionString");
        _connectionString = connectionString;
    }

    /// <summary>Translates <paramref name="expression"/> (so ToString shows its SQL) and wraps it in a new queryable.</summary>
    public IQueryable<T> CreateQuery<T>(Expression expression)
    {
        Translate(expression);
        return new CustomTable<T>(expression, this);
    }

    /// <summary>Non-generic counterpart of <see cref="CreateQuery{T}"/>; the element type is taken from the expression type.</summary>
    public IQueryable CreateQuery(Expression expression)
    {
        Translate(expression);
        object[] args = new object[] { expression, this };
        return (IQueryable)Activator.CreateInstance(typeof(CustomTable<>).MakeGenericType(_ElementType), args);
    }

    /// <summary>Executes <paramref name="expression"/>; T is expected to be IEnumerable of the element type.</summary>
    public T Execute<T>(Expression expression)
    {
        return (T)ExecuteSql(expression);
    }

    /// <summary>Executes <paramref name="expression"/> and returns the materialized List of elements.</summary>
    public object Execute(Expression expression)
    {
        return ExecuteSql(expression);
    }

    /// <summary>
    /// Rebuilds all translation state from <paramref name="expression"/> alone, so the resulting
    /// SQL never depends on which query this shared provider translated before.
    /// </summary>
    private void Translate(Expression expression)
    {
        sql = "";
        count = 0;
        tableName = "";
        selector = "";
        where = "";
        _PreType = null;

        if (!expression.Type.IsGenericType)
            throw new NotSupportedException("Only queries that return a sequence are supported.");
        _ElementType = expression.Type.GetGenericArguments()[0];

        if (expression is ConstantExpression)
        {
            // The root table is being enumerated directly: select every column.
            SetTableName(expression);
            sql = "select t" + count + ".* from " + tableName;
            return;
        }
        TranslateCall(expression);
    }

    /// <summary>Replays the operator chain innermost-first, giving each level the next tN alias.</summary>
    private void TranslateCall(Expression expression)
    {
        MethodCallExpression call = expression as MethodCallExpression;
        if (call == null)
            return; // reached the root table constant; SetTableName handles it for the caller

        if (call.Method.DeclaringType != typeof(Queryable)
            || (call.Method.Name != "Where" && call.Method.Name != "Select")
            || call.Arguments.Count != 2)
            throw new NotSupportedException("Unsupported query operator: " + call.Method.Name);

        // The inner query must be translated first because this level wraps its SQL as a subquery.
        TranslateCall(call.Arguments[0]);
        SetQueryText(call);
        count++;
    }

    /// <summary>Builds <c>sql</c> for one Where/Select call on top of the already translated source.</summary>
    private void SetQueryText(Expression expression)
    {
        MethodCallExpression call = (MethodCallExpression)expression;
        SetTableName(call.Arguments[0]);
        // Earlier filters already live inside the wrapped subquery (tableName); repeating them here
        // produced "where ... where ..." when two Where calls were chained.
        where = "";

        LambdaExpression lambda = StripQuotes(call.Arguments[1]) as LambdaExpression;
        if (lambda == null || lambda.Parameters.Count != 1)
            throw new NotSupportedException(call.Method.Name + " requires a single-parameter lambda.");

        switch (call.Method.Name)
        {
            case "Where":
                selector = "select t" + count + ".* ";
                BinaryExpression predicate = lambda.Body as BinaryExpression;
                if (predicate == null)
                    throw new NotSupportedException("A Where predicate must be a single comparison.");
                ProcessBinary(predicate);
                break;
            case "Select":
                if (lambda.Body is NewExpression)
                    ProcessNew((NewExpression)lambda.Body);
                else if (lambda.Body == lambda.Parameters[0])
                    selector = "select t" + count + ".* "; // identity projection: r => r
                else
                    throw new NotSupportedException("Select only supports anonymous-type projections of members.");
                break;
            default:
                throw new NotSupportedException("Unsupported query operator: " + call.Method.Name);
        }
        sql = selector + " from " + tableName + " " + where;
    }

    /// <summary>Sets the FROM source: the mapped table for the root, or the previous SQL as a derived table.</summary>
    private void SetTableName(Expression expression)
    {
        if (!(expression is ConstantExpression) && !(expression is MethodCallExpression))
            throw new NotSupportedException("Unsupported query source: " + expression.NodeType);

        _PreType = expression.Type.GetGenericArguments()[0];
        if (expression is ConstantExpression)
            tableName = MapHelper.GetTableName(_PreType) + " as t" + count + " ";
        else
            tableName = "( " + sql + " ) as t" + count + " ";
    }

    /// <summary>Translates "member op value" into the WHERE clause of the current level.</summary>
    void ProcessBinary(BinaryExpression expression)
    {
        if (expression.Left is BinaryExpression || expression.Right is BinaryExpression)
            throw new NotSupportedException("only be one binary");

        // Implicit conversions (e.g. nullable or enum members) wrap the member access in Convert.
        MemberExpression member = StripConvert(expression.Left) as MemberExpression;
        if (member == null || !(member.Expression is ParameterExpression))
            throw new NotSupportedException("The left side of a comparison must be a member of the queried element.");

        string propertyname = member.Member.Name;
        string column = "t" + count + "." + MapHelper.GetColumnName(_PreType, propertyname);
        object value = EvaluateOperand(expression.Right);

        if (value == null)
        {
            // "= NULL" never matches in SQL; comparisons with null need IS [NOT] NULL.
            if (expression.NodeType == ExpressionType.Equal)
                where = " where " + column + " is null";
            else if (expression.NodeType == ExpressionType.NotEqual)
                where = " where " + column + " is not null";
            else
                throw new NotSupportedException("Only == and != can compare with null.");
            return;
        }

        Type type = MapHelper.GetColumnType(_PreType, propertyname);
        where = " where " + column + GetSqlOperator(expression.NodeType) + ToSqlLiteral(value, type);
    }

    /// <summary>Builds the SELECT list for an anonymous-type projection of plain member accesses.</summary>
    void ProcessNew(NewExpression expression)
    {
        if (expression.Members == null || expression.Members.Count == 0)
            throw new NotSupportedException("Select projections must be anonymous types with at least one member.");

        List<string> columns = new List<string>();
        for (int i = 0; i < expression.Arguments.Count; i++)
        {
            MemberExpression arg = expression.Arguments[i] as MemberExpression;
            if (arg == null)
                throw new NotSupportedException("Select projections may only contain plain member accesses.");

            string newName = expression.Members[i].Name;
            string columnName = MapHelper.GetColumnName(_PreType, arg.Member.Name);
            string item = "t" + count + "." + columnName;
            // Alias whenever the SQL column name differs from the projected member name: the
            // materializer looks result columns up by member name (e.g. r.Name maps to StuName).
            if (columnName != newName)
                item += " as " + newName + " ";
            columns.Add(item);
        }
        selector = "select " + string.Join(",", columns.ToArray());
    }

    /// <summary>Runs the SQL translated from <paramref name="expression"/> and materializes a List of the element type.</summary>
    private object ExecuteSql(Expression expression)
    {
        // Translate the expression actually being executed; the provider is shared by every query
        // derived from the same table, so the last CreateQuery may have belonged to another query.
        Translate(expression);

        DataSet ds = new DataSet();
        using (SqlConnection connection = new SqlConnection(_connectionString))
        using (SqlCommand cmd = new SqlCommand(sql, connection))
        using (SqlDataAdapter da = new SqlDataAdapter(cmd))
        {
            connection.Open();
            da.Fill(ds);
        }
        return Table2Entity.ConvertFromTable(ds.Tables[0], _ElementType);
    }

    /// <summary>Returns the most recently translated SQL.</summary>
    public override string ToString()
    {
        return sql;
    }

    // Removes the Quote nodes Queryable wraps around lambda arguments.
    private static Expression StripQuotes(Expression e)
    {
        while (e.NodeType == ExpressionType.Quote)
            e = ((UnaryExpression)e).Operand;
        return e;
    }

    // Removes Convert/ConvertChecked wrappers around an operand.
    private static Expression StripConvert(Expression e)
    {
        while (e.NodeType == ExpressionType.Convert || e.NodeType == ExpressionType.ConvertChecked)
            e = ((UnaryExpression)e).Operand;
        return e;
    }

    // Evaluates a parameter-free operand. Captured locals arrive as member accesses on a closure
    // object rather than as ConstantExpression; they used to be translated as ''. An operand that
    // references the lambda parameter fails to compile (InvalidOperationException).
    private static object EvaluateOperand(Expression operand)
    {
        ConstantExpression constant = operand as ConstantExpression;
        if (constant != null)
            return constant.Value;
        return Expression.Lambda(operand).Compile().DynamicInvoke();
    }

    // Maps a comparison node to its SQL operator; other node types used to yield an empty operator.
    private static string GetSqlOperator(ExpressionType nodeType)
    {
        switch (nodeType)
        {
            case ExpressionType.Equal: return " = ";
            case ExpressionType.NotEqual: return " <> ";
            case ExpressionType.LessThan: return " < ";
            case ExpressionType.LessThanOrEqual: return " <= ";
            case ExpressionType.GreaterThan: return " > ";
            case ExpressionType.GreaterThanOrEqual: return " >= ";
            default: throw new NotSupportedException("Unsupported comparison operator: " + nodeType);
        }
    }

    // Renders a value as a SQL literal for a column of the given CLR type. Numbers and dates use
    // the invariant culture (current-culture decimal commas broke the SQL), and strings have
    // single quotes doubled. Unsupported column types throw instead of silently dropping the filter.
    private static string ToSqlLiteral(object value, Type columnType)
    {
        Type type = columnType == null ? null : (Nullable.GetUnderlyingType(columnType) ?? columnType);

        string text;
        if (value is DateTime)
            text = ((DateTime)value).ToString("yyyy-MM-dd'T'HH:mm:ss.fff", System.Globalization.CultureInfo.InvariantCulture);
        else if (value is bool)
            text = (bool)value ? "1" : "0";
        else
            text = Convert.ToString(value, System.Globalization.CultureInfo.InvariantCulture);

        switch (Type.GetTypeCode(type))
        {
            case TypeCode.Boolean:
            case TypeCode.SByte:
            case TypeCode.Byte:
            case TypeCode.Int16:
            case TypeCode.UInt16:
            case TypeCode.Int32:
            case TypeCode.UInt32:
            case TypeCode.Int64:
            case TypeCode.UInt64:
            case TypeCode.Single:
            case TypeCode.Double:
            case TypeCode.Decimal:
                // Unquoted literals are only safe when the text was produced from a non-string value.
                if (value is string)
                    throw new NotSupportedException("A string value cannot be compared with a numeric column.");
                return text;
            case TypeCode.Char:
            case TypeCode.String:
            case TypeCode.DateTime:
                return "'" + text.Replace("'", "''") + "'";
            default:
                throw new NotSupportedException("Columns of type " + columnType + " cannot be used in a Where clause.");
        }
    }
}
}
实体、属性与数据库表、列之间的映射帮助类:
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Reflection;
namespace SimpleLinq2Sql
{
/// <summary>
/// Resolves table names, column names and column CLR types from <c>[Table]</c>/<c>[Column]</c>
/// attributes, falling back to the property name and type when a property is not annotated.
/// </summary>
public static class MapHelper
{
    // Same lookup flags as before: public and non-public instance properties.
    private const BindingFlags PropertyFlags = BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance;

    /// <summary>Returns the table name declared by <c>[Table]</c> directly on <paramref name="type"/>.</summary>
    /// <exception cref="ArgumentNullException"><paramref name="type"/> is null.</exception>
    /// <exception cref="InvalidOperationException">The type carries no <c>[Table]</c> attribute.</exception>
    public static string GetTableName(Type type)
    {
        if (type == null)
            throw new ArgumentNullException("type");
        // inherit: false, matching the original IsDefined(..., false) check.
        TableAttribute ta = (TableAttribute)Attribute.GetCustomAttribute(type, typeof(TableAttribute), false);
        if (ta == null)
            throw new InvalidOperationException("Type " + type.FullName + " is not mapped to a table; add [Table(\"...\")] to it.");
        return ta.TableName;
    }

    /// <summary>Returns the mapped column name of a property, or the property name itself when it has no <c>[Column]</c>.</summary>
    /// <exception cref="ArgumentException">The type has no such instance property.</exception>
    public static string GetColumnName(Type type, string propertyName)
    {
        ColumnAttribute ca = GetColumnAttribute(FindProperty(type, propertyName));
        return ca == null ? propertyName : ca.ColumnName;
    }

    /// <summary>
    /// Returns the CLR type of a mapped column: the type implied by <c>[Column(ColumnType = ...)]</c>
    /// (String unless set), or the property's own type when it has no <c>[Column]</c>.
    /// </summary>
    /// <exception cref="ArgumentException">The type has no such instance property.</exception>
    public static Type GetColumnType(Type type, string propertyName)
    {
        PropertyInfo pi = FindProperty(type, propertyName);
        ColumnAttribute ca = GetColumnAttribute(pi);
        return ca == null ? pi.PropertyType : SwithType(ca.ColumnType);
    }

    // Looks up an instance property, failing with a message that names what was missing.
    private static PropertyInfo FindProperty(Type type, string propertyName)
    {
        if (type == null)
            throw new ArgumentNullException("type");
        PropertyInfo pi = type.GetProperty(propertyName, PropertyFlags);
        if (pi == null)
            throw new ArgumentException("Type " + type.FullName + " has no instance property named '" + propertyName + "'.", "propertyName");
        return pi;
    }

    // [Column] declared directly on the property (inherit: false, as before), or null.
    private static ColumnAttribute GetColumnAttribute(PropertyInfo pi)
    {
        return (ColumnAttribute)Attribute.GetCustomAttribute(pi, typeof(ColumnAttribute), false);
    }

    // Maps the DataType enum to a CLR type; undefined values used to return null.
    static Type SwithType(DataType dtype)
    {
        switch (dtype)
        {
            case DataType.String: return typeof(String);
            case DataType.Int: return typeof(Int32);
            case DataType.DateTime: return typeof(DateTime);
            case DataType.Float: return typeof(float);
            case DataType.Double: return typeof(double);
            default: throw new ArgumentOutOfRangeException("dtype", dtype, "Unknown column data type.");
        }
    }
}
}
自定义TableAttribute
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
namespace SimpleLinq2Sql
{
/// <summary>
/// Maps an entity class to a database table. MapHelper reads it to name the table in FROM clauses.
/// </summary>
[AttributeUsage(AttributeTargets.Class)]
internal class TableAttribute : Attribute
{
    /// <summary>Stores the table name exactly as supplied (null is kept as-is).</summary>
    /// <param name="tableName">Name of the table the class maps to.</param>
    public TableAttribute(string tableName)
    {
        this.tableName = tableName;
    }

    // Set once by the constructor and never changed afterwards.
    private readonly string tableName;

    /// <summary>The table name passed to the constructor.</summary>
    public string TableName
    {
        get
        {
            return tableName;
        }
    }
}
}
自定义ColumnAttribute
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
namespace SimpleLinq2Sql
{
/// <summary>
/// Maps a property to a database column and declares the column's data type
/// (<see cref="DataType.String"/> unless <see cref="ColumnType"/> is set).
/// </summary>
[AttributeUsage(AttributeTargets.Property)]
internal class ColumnAttribute : Attribute
{
    /// <summary>Records the column name; the data type starts as <see cref="DataType.String"/>.</summary>
    /// <param name="columnName">Name of the column the property maps to.</param>
    public ColumnAttribute(string columnName)
    {
        this.columnName = columnName;
        ColumnType = DataType.String;
    }

    // Fixed at construction time.
    private readonly string columnName;

    /// <summary>The column name passed to the constructor.</summary>
    public string ColumnName
    {
        get
        {
            return columnName;
        }
    }

    /// <summary>Declared column data type; settable as a named attribute argument.</summary>
    public DataType ColumnType { get; set; }
}
}
数据类型枚举
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
namespace SimpleLinq2Sql
{
/// <summary>
/// Column data types that can be declared through <c>ColumnAttribute.ColumnType</c>.
/// The values keep the implicit 0-based numbering of the original declaration.
/// </summary>
public enum DataType
{
    /// <summary>Integer column (mapped to <see cref="System.Int32"/>).</summary>
    Int = 0,

    /// <summary>Text column (mapped to <see cref="System.String"/>).</summary>
    String = 1,

    /// <summary>Single-precision column (mapped to <see cref="System.Single"/>).</summary>
    Float = 2,

    /// <summary>Double-precision column (mapped to <see cref="System.Double"/>).</summary>
    Double = 3,

    /// <summary>Date/time column (mapped to <see cref="System.DateTime"/>).</summary>
    DateTime = 4
}
}
将 DataTable 转换为对应实体类:
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Data;
using System.Reflection;
namespace SimpleLinq2Sql
{
/// <summary>
/// Materializes <see cref="DataTable"/> rows into instances of an entity or projection type.
/// </summary>
internal static class Table2Entity
{
    /// <summary>
    /// Converts one row.
    /// <c>[Table]</c> entities are created with their parameterless constructor and have every
    /// writable property set from its mapped column. Other types (e.g. anonymous projections) are
    /// created through their single public constructor, binding each parameter by name.
    /// </summary>
    static object ConvertFromDataRow(DataRow dr, Type type)
    {
        if (type.IsDefined(typeof(TableAttribute), false))
            return CreateMappedEntity(dr, type);
        return CreateFromConstructor(dr, type);
    }

    /// <summary>Converts every row of <paramref name="dt"/> and returns them as a List of <paramref name="type"/>.</summary>
    public static object ConvertFromTable(DataTable dt, Type type)
    {
        // A List<type> lets callers cast the result to IEnumerable<T> for the runtime element type.
        System.Collections.IList list = (System.Collections.IList)Activator.CreateInstance(typeof(List<>).MakeGenericType(type));
        foreach (DataRow dr in dt.Rows)
        {
            list.Add(ConvertFromDataRow(dr, type));
        }
        return list;
    }

    // [Table] entity: parameterless constructor, then one mapped column per writable instance property.
    private static object CreateMappedEntity(DataRow dr, Type type)
    {
        object entity = Activator.CreateInstance(type);
        PropertyInfo[] properties = type.GetProperties(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance);
        foreach (PropertyInfo p in properties)
        {
            // Get-only (computed) properties and indexers have no column to load. Static properties
            // are excluded because MapHelper only resolves instance properties.
            if (!p.CanWrite || p.GetIndexParameters().Length > 0)
                continue;
            string column = MapHelper.GetColumnName(type, p.Name);
            p.SetValue(entity, ReadColumn(dr, column, p.PropertyType), null);
        }
        return entity;
    }

    // Constructor-bound type: arguments are matched to columns by parameter name. The previous
    // code relied on GetProperties order, which reflection does not guarantee to match the
    // constructor's parameter order.
    private static object CreateFromConstructor(DataRow dr, Type type)
    {
        ConstructorInfo[] ctors = type.GetConstructors();
        if (ctors.Length != 1 || ctors[0].GetParameters().Length == 0)
            throw new NotSupportedException("Type " + type.FullName + " must carry [Table] or expose a single public constructor that takes its column values (e.g. an anonymous type).");

        ParameterInfo[] parameters = ctors[0].GetParameters();
        object[] args = new object[parameters.Length];
        for (int i = 0; i < parameters.Length; i++)
        {
            args[i] = ReadColumn(dr, parameters[i].Name, parameters[i].ParameterType);
        }
        return ctors[0].Invoke(args);
    }

    // Reads a column and converts it to targetType. DB NULL becomes null for reference and
    // Nullable<T> targets, and Nullable<T> converts via its underlying type. Convert.ChangeType
    // rejects both cases.
    private static object ReadColumn(DataRow dr, string columnName, Type targetType)
    {
        if (!dr.Table.Columns.Contains(columnName))
            throw new InvalidOperationException("The result set has no column named '" + columnName + "'.");

        object raw = dr[columnName];
        Type underlying = Nullable.GetUnderlyingType(targetType);
        if (raw == null || raw == DBNull.Value)
        {
            if (!targetType.IsValueType || underlying != null)
                return null;
            throw new InvalidOperationException("Column '" + columnName + "' is NULL but " + targetType.FullName + " cannot hold null.");
        }

        Type target = underlying ?? targetType;
        if (target.IsInstanceOfType(raw))
            return raw;
        if (target.IsEnum)
            return Enum.ToObject(target, raw);
        return Convert.ChangeType(raw, target, System.Globalization.CultureInfo.InvariantCulture);
    }
}
}
自定义的实体类:
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
namespace SimpleLinq2Sql
{
/// <summary>
/// Sample entity mapped to the <c>Student</c> table. Each property carries a <c>[Column]</c>
/// mapping; <see cref="Name"/> is stored in the <c>StuName</c> column.
/// </summary>
[Table("Student")]
public class Student
{
    /// <summary>Primary identifier (column ID, integer).</summary>
    [Column("ID", ColumnType = DataType.Int)]
    public int ID { get; set; }

    /// <summary>Student name (column StuName, text).</summary>
    [Column("StuName")]
    public string Name { get; set; }

    /// <summary>Address (column Address, text).</summary>
    [Column("Address")]
    public string Address { get; set; }

    /// <summary>Sex code (column Sex, integer).</summary>
    [Column("Sex", ColumnType = DataType.Int)]
    public int Sex { get; set; }

    /// <summary>College identifier (column CollegeID, integer).</summary>
    [Column("CollegeID", ColumnType = DataType.Int)]
    public int CollegeID { get; set; }
}
}
Program 执行入口:
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
namespace SimpleLinq2Sql
{
class Program
{
    /// <summary>
    /// Demo entry point. Builds a Where/Select/Where query over Student, prints the generated SQL,
    /// prints each projected row as comma-separated values, and waits for input before exiting.
    /// </summary>
    static void Main(string[] args)
    {
        CustomTable<Student> students = new CustomTable<Student>();

        // Query syntax compiles to the same Where(...).Select(...) call chain as method syntax.
        var fromChina = from r in students
                        where r.Address == "china"
                        select new { NewName = r.Name, Country = r.Address, r.CollegeID, r.Sex };
        var query = fromChina.Where(r => r.Sex == 1);

        Console.WriteLine(query.ToString());
        foreach (var row in query)
        {
            Console.WriteLine(string.Format("{0},{1},{2},{3}", row.NewName, row.Country, row.CollegeID, row.Sex));
        }
        Console.Read();
    }
}
}