简单实现一个自定义的 LINQ to SQL Provider

这两天利用空闲时间研究了一下 LINQ 的查询提供器(Query Provider),并做了一个简单的实现。代码写得比较乱,没有注释,也没怎么做设计,因此有很多临时变量和一些不必要的操作,重点在于展示实现原理。微软的 LINQ to SQL 实现相当复杂,这个例子只简单实现了 Select 和 Where,其他操作都没有实现;而且 Where 查询只支持有限的 ==、>、< 运算符。不过这并不重要,如有需要可以自行添加对应的实现。

先把代码记录下来,以后有时间再优化代码、补充注释。


IQueryable的实现:

using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Linq.Expressions;
using System.Collections;

namespace SimpleLinq2Sql
{
    /// <summary>
    /// Queryable sequence of <typeparamref name="T"/>. An instance created with the parameterless
    /// constructor is the root of a query and is represented by a constant expression of itself;
    /// query operators produce further instances that carry the composed expression and share
    /// the same provider.
    /// </summary>
    /// <typeparam name="T">Element type yielded when the query is enumerated.</typeparam>
    public class CustomTable<T> : IQueryable<T>
    {
        private readonly Type _ElementType;
        private readonly Expression _Expression;
        private readonly IQueryProvider _Provider;

        /// <summary>Always <c>typeof(T)</c>.</summary>
        public Type ElementType
        {
            get { return _ElementType; }
        }

        /// <summary>Expression tree describing this query.</summary>
        public Expression Expression
        {
            get { return _Expression; }
        }

        /// <summary>Provider used to compose further queries and to execute this one.</summary>
        public IQueryProvider Provider
        {
            get { return _Provider; }
        }

        /// <summary>Creates a query node for an already-composed expression (called by the provider).</summary>
        /// <param name="expression">Expression tree this query represents.</param>
        /// <param name="provider">Provider that executes the expression.</param>
        /// <exception cref="ArgumentNullException"><paramref name="expression"/> or <paramref name="provider"/> is null.</exception>
        public CustomTable(Expression expression, IQueryProvider provider)
        {
            if (provider == null)
                throw new ArgumentNullException("provider", "provider can't be null");
            // A null expression would only fail later, when the query is enumerated.
            if (expression == null)
                throw new ArgumentNullException("expression");
            _ElementType = typeof(T);
            _Expression = expression;
            _Provider = provider;
        }

        /// <summary>Creates a root table backed by a new <see cref="CustomProvider"/>.</summary>
        public CustomTable()
        {
            _ElementType = typeof(T);
            _Provider = new CustomProvider();
            // The root of every query tree is a constant that refers to this table.
            _Expression = Expression.Constant(this);
        }

        /// <summary>Executes the query through the provider and enumerates the results.</summary>
        public IEnumerator<T> GetEnumerator()
        {
            return (Provider.Execute<IEnumerable<T>>(Expression)).GetEnumerator();
        }

        IEnumerator IEnumerable.GetEnumerator()
        {
            return GetEnumerator();
        }

        /// <summary>
        /// Returns the provider's <c>ToString()</c>. For <see cref="CustomProvider"/> that is the SQL text
        /// it generated most recently, which is not necessarily the SQL of this particular query.
        /// </summary>
        public override string ToString()
        {
            return _Provider.ToString();
        }
    }
}

IQueryProvider的实现:

using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Linq.Expressions;
using System.Reflection;
using System.Data;
using System.Data.SqlClient;

namespace SimpleLinq2Sql
{
    /// <summary>
    /// Minimal <see cref="IQueryProvider"/> that turns chains of <c>Queryable.Where</c> and
    /// <c>Queryable.Select</c> over a <see cref="CustomTable{T}"/> into parameterized SQL, runs it
    /// with <see cref="SqlConnection"/>, and materializes the rows via <c>Table2Entity</c>.
    /// </summary>
    /// <remarks>
    /// Every operator wraps its source in a derived table aliased t0, t1, ... counting outward
    /// from the table. SQL is generated from the expression passed to each call, so several queries
    /// composed from the same table can safely share one provider instance.
    /// Supported: <c>Where</c> with comparisons (==, !=, &lt;, &lt;=, &gt;, &gt;=) of a row member
    /// against a constant or captured value, combined with &amp;&amp; / ||; <c>Select</c> into an
    /// anonymous type built from row members, or of the row itself. Anything else throws
    /// <see cref="NotSupportedException"/>.
    /// </remarks>
    public class CustomProvider : IQueryProvider
    {
        // Connection used by the parameterless constructor (hard-coded in the original demo).
        private const string DefaultConnectionString = "Data Source=.;Initial Catalog=TestLinq;Integrated Security=True";

        private readonly string connectionString;

        // SQL of the most recently created query; only reported by ToString().
        private string sql = "";

        /// <summary>Creates a provider that uses the built-in default connection string.</summary>
        public CustomProvider()
            : this(DefaultConnectionString)
        {
        }

        /// <summary>Creates a provider that connects with <paramref name="connectionString"/>.</summary>
        /// <exception cref="ArgumentNullException"><paramref name="connectionString"/> is null.</exception>
        public CustomProvider(string connectionString)
        {
            if (connectionString == null)
                throw new ArgumentNullException("connectionString");
            this.connectionString = connectionString;
        }

        /// <summary>
        /// Wraps <paramref name="expression"/> in a new query. The expression is translated right away
        /// so unsupported operators fail while the query is composed rather than when it is enumerated.
        /// </summary>
        public IQueryable<T> CreateQuery<T>(Expression expression)
        {
            sql = new QueryTranslator().Translate(expression);
            return new CustomTable<T>(expression, this);
        }

        /// <summary>Non-generic counterpart of <c>CreateQuery&lt;T&gt;</c>.</summary>
        public IQueryable CreateQuery(Expression expression)
        {
            sql = new QueryTranslator().Translate(expression);
            Type elementType = GetElementType(expression.Type);
            object[] args = new object[] { expression, this };
            return (IQueryable)Activator.CreateInstance(typeof(CustomTable<>).MakeGenericType(elementType), args);
        }

        /// <summary>
        /// Translates and runs <paramref name="expression"/>. The result is a <c>List</c> of the
        /// sequence's element type, so <typeparamref name="T"/> should be an <c>IEnumerable&lt;TElement&gt;</c>.
        /// </summary>
        public T Execute<T>(Expression expression)
        {
            return (T)ExecuteSql(expression);
        }

        /// <summary>Non-generic counterpart of <c>Execute&lt;T&gt;</c>.</summary>
        public object Execute(Expression expression)
        {
            return ExecuteSql(expression);
        }

        /// <summary>SQL text (with @p parameter placeholders) of the most recently created query.</summary>
        public override string ToString()
        {
            return sql;
        }

        // Translates the given expression (not whatever was composed last), runs it with its parameters,
        // and converts the first result table into a list of the expression's element type.
        private object ExecuteSql(Expression expression)
        {
            QueryTranslator translator = new QueryTranslator();
            string commandText = translator.Translate(expression);
            Type elementType = GetElementType(expression.Type);

            DataSet ds = new DataSet();
            using (SqlConnection connection = new SqlConnection(connectionString))
            using (SqlCommand cmd = new SqlCommand(commandText, connection))
            using (SqlDataAdapter da = new SqlDataAdapter(cmd))
            {
                // Values travel as parameters: no quoting/escaping or culture-dependent formatting in the SQL text.
                foreach (KeyValuePair<string, object> parameter in translator.Parameters)
                    cmd.Parameters.AddWithValue(parameter.Key, parameter.Value);
                connection.Open();
                da.Fill(ds);
            }
            return Table2Entity.ConvertFromTable(ds.Tables[0], elementType);
        }

        // Returns T for a type that is, or implements, IEnumerable<T> (CustomTable<T>, IQueryable<T>, ...).
        private static Type GetElementType(Type sequenceType)
        {
            if (sequenceType.IsGenericType && sequenceType.GetGenericTypeDefinition() == typeof(IEnumerable<>))
                return sequenceType.GetGenericArguments()[0];
            foreach (Type candidate in sequenceType.GetInterfaces())
            {
                if (candidate.IsGenericType && candidate.GetGenericTypeDefinition() == typeof(IEnumerable<>))
                    return candidate.GetGenericArguments()[0];
            }
            throw new NotSupportedException("Type '" + sequenceType + "' is not a sequence.");
        }

        /// <summary>Translates one expression tree into SQL text plus the parameter values it references.</summary>
        private sealed class QueryTranslator
        {
            private readonly List<KeyValuePair<string, object>> parameters = new List<KeyValuePair<string, object>>();

            /// <summary>Parameter names (@p0, @p1, ...) and values collected during translation.</summary>
            public List<KeyValuePair<string, object>> Parameters
            {
                get { return parameters; }
            }

            /// <summary>Returns the SQL for <paramref name="expression"/>.</summary>
            /// <exception cref="ArgumentNullException"><paramref name="expression"/> is null.</exception>
            /// <exception cref="NotSupportedException">The expression uses an unsupported operator or shape.</exception>
            public string Translate(Expression expression)
            {
                if (expression == null)
                    throw new ArgumentNullException("expression");
                int operatorCount;
                return TranslateSequence(expression, out operatorCount);
            }

            // Returns SQL for a sequence expression; operatorCount receives how many Where/Select calls it contains,
            // which is also the alias index the next operator wrapped around it will use.
            private string TranslateSequence(Expression expression, out int operatorCount)
            {
                if (expression is ConstantExpression)
                {
                    // The bare table (no operators applied).
                    operatorCount = 0;
                    return "select * from " + MapHelper.GetTableName(GetElementType(expression.Type));
                }

                MethodCallExpression call = expression as MethodCallExpression;
                if (call == null
                    || call.Method.DeclaringType != typeof(Queryable)
                    || call.Arguments.Count != 2
                    || (call.Method.Name != "Where" && call.Method.Name != "Select"))
                {
                    throw new NotSupportedException("Only Queryable.Where and Queryable.Select are supported: " + expression);
                }

                Expression source = call.Arguments[0];
                // Row type of the source; member names in the lambda are mapped to its columns.
                Type sourceType = GetElementType(source.Type);

                int sourceCount;
                string from;
                if (source is ConstantExpression)
                {
                    sourceCount = 0;
                    from = MapHelper.GetTableName(sourceType);
                }
                else
                {
                    from = "( " + TranslateSequence(source, out sourceCount) + " )";
                }

                string alias = "t" + sourceCount;
                operatorCount = sourceCount + 1;
                LambdaExpression lambda = GetLambda(call.Arguments[1]);

                if (call.Method.Name == "Where")
                {
                    return "select " + alias + ".* from " + from + " as " + alias
                        + " where " + TranslatePredicate(lambda.Body, sourceType, alias);
                }
                return "select " + TranslateProjection(lambda.Body, sourceType, alias)
                    + " from " + from + " as " + alias;
            }

            // Unwraps the Quote node Queryable puts around lambda arguments; rejects index-taking overloads.
            private static LambdaExpression GetLambda(Expression expression)
            {
                while (expression.NodeType == ExpressionType.Quote)
                    expression = ((UnaryExpression)expression).Operand;
                LambdaExpression lambda = expression as LambdaExpression;
                if (lambda == null || lambda.Parameters.Count != 1)
                    throw new NotSupportedException("Expected a single-parameter lambda: " + expression);
                return lambda;
            }

            // Translates a Where body: comparisons "row.Member op value", combined with && and ||.
            private string TranslatePredicate(Expression expression, Type sourceType, string alias)
            {
                BinaryExpression binary = expression as BinaryExpression;
                if (binary == null)
                    throw new NotSupportedException("Unsupported where clause: " + expression);

                string op;
                switch (binary.NodeType)
                {
                    case ExpressionType.AndAlso:
                        return "(" + TranslatePredicate(binary.Left, sourceType, alias) + " and "
                            + TranslatePredicate(binary.Right, sourceType, alias) + ")";
                    case ExpressionType.OrElse:
                        return "(" + TranslatePredicate(binary.Left, sourceType, alias) + " or "
                            + TranslatePredicate(binary.Right, sourceType, alias) + ")";
                    case ExpressionType.Equal: op = " = "; break;
                    case ExpressionType.NotEqual: op = " <> "; break;
                    case ExpressionType.LessThan: op = " < "; break;
                    case ExpressionType.LessThanOrEqual: op = " <= "; break;
                    case ExpressionType.GreaterThan: op = " > "; break;
                    case ExpressionType.GreaterThanOrEqual: op = " >= "; break;
                    default:
                        throw new NotSupportedException("Unsupported operator " + binary.NodeType + " in: " + expression);
                }

                // Enum and nullable comparisons wrap the member in Convert nodes.
                MemberExpression member = StripConvert(binary.Left) as MemberExpression;
                if (member == null || !(member.Expression is ParameterExpression))
                    throw new NotSupportedException("The left side of a comparison must be a member of the lambda parameter: " + expression);

                string column = alias + "." + MapHelper.GetColumnName(sourceType, member.Member.Name);
                object value = Evaluate(binary.Right);

                // "= NULL" never matches in SQL; use IS [NOT] NULL instead.
                if (value == null && binary.NodeType == ExpressionType.Equal)
                    return column + " is null";
                if (value == null && binary.NodeType == ExpressionType.NotEqual)
                    return column + " is not null";

                string name = "@p" + parameters.Count;
                parameters.Add(new KeyValuePair<string, object>(name, value ?? DBNull.Value));
                return column + op + name;
            }

            // Computes a comparison operand that does not depend on the row: a constant or a captured variable.
            private static object Evaluate(Expression expression)
            {
                ConstantExpression constant = expression as ConstantExpression;
                if (constant != null)
                    return constant.Value;
                try
                {
                    return Expression.Lambda(expression).Compile().DynamicInvoke();
                }
                catch (InvalidOperationException ex)
                {
                    // Compile() rejects operands that reference the lambda parameter, e.g. r.A == r.B.
                    throw new NotSupportedException("The right side of a comparison must not depend on the row: " + expression, ex);
                }
            }

            private static Expression StripConvert(Expression expression)
            {
                while (expression.NodeType == ExpressionType.Convert || expression.NodeType == ExpressionType.ConvertChecked)
                    expression = ((UnaryExpression)expression).Operand;
                return expression;
            }

            // Translates a Select body: the row itself, or "new { A = r.X, r.Y, ... }" over row members.
            private static string TranslateProjection(Expression expression, Type sourceType, string alias)
            {
                if (expression is ParameterExpression)
                    return alias + ".*";

                NewExpression init = expression as NewExpression;
                if (init == null || init.Members == null || init.Arguments.Count == 0)
                    throw new NotSupportedException("Select only supports projections like r => new { r.A, B = r.C }: " + expression);

                List<string> columns = new List<string>();
                for (int i = 0; i < init.Arguments.Count; i++)
                {
                    MemberExpression source = init.Arguments[i] as MemberExpression;
                    if (source == null || !(source.Expression is ParameterExpression))
                        throw new NotSupportedException("Projected values must be members of the lambda parameter: " + init.Arguments[i]);

                    string column = MapHelper.GetColumnName(sourceType, source.Member.Name);
                    string resultName = init.Members[i].Name;
                    // Alias whenever the column name differs from the result member name (e.g. Name -> StuName),
                    // so the materializer finds each column under the result type's member name.
                    columns.Add(column == resultName
                        ? alias + "." + column
                        : alias + "." + column + " as " + resultName);
                }
                return string.Join(", ", columns.ToArray());
            }
        }
    }
}

实体、属性与数据库表、列之间的映射帮助类:

using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Reflection;

namespace SimpleLinq2Sql
{
    /// <summary>
    /// Reflection helpers that map entity types and properties to database table/column names
    /// and column types, driven by <c>TableAttribute</c> and <c>ColumnAttribute</c>.
    /// </summary>
    public static class MapHelper
    {
        // Property lookup used by the member-level helpers: public or non-public, instance only.
        const BindingFlags PropertyFlags = BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance;

        /// <summary>Returns the table name declared by <c>[Table]</c> on <paramref name="type"/>.</summary>
        /// <exception cref="ArgumentNullException"><paramref name="type"/> is null.</exception>
        /// <exception cref="ArgumentException"><paramref name="type"/> is not marked with <c>[Table]</c>.</exception>
        public static string GetTableName(Type type)
        {
            if (type == null)
                throw new ArgumentNullException("type");
            // inherit: false, matching the IsDefined(..., false) check used elsewhere in the mapping.
            TableAttribute ta = (TableAttribute)Attribute.GetCustomAttribute(type, typeof(TableAttribute), false);
            if (ta == null)
                throw new ArgumentException("Type '" + type.FullName + "' is not marked with [Table].", "type");
            return ta.TableName;
        }

        /// <summary>
        /// Returns the column name for a property: the <c>[Column]</c> name when present, otherwise
        /// the property name itself.
        /// </summary>
        /// <exception cref="ArgumentNullException">An argument is null.</exception>
        /// <exception cref="ArgumentException">The type has no such property.</exception>
        public static string GetColumnName(Type type, string propertyName)
        {
            PropertyInfo pi = FindProperty(type, propertyName);
            ColumnAttribute ca = (ColumnAttribute)Attribute.GetCustomAttribute(pi, typeof(ColumnAttribute), false);
            return ca == null ? propertyName : ca.ColumnName;
        }

        /// <summary>
        /// Returns the CLR type of a property's column: the type declared by <c>[Column(ColumnType = ...)]</c>
        /// when the attribute is present, otherwise the property's own type.
        /// </summary>
        /// <remarks>
        /// A <c>[Column]</c> without an explicit ColumnType reports the attribute's default (String),
        /// even when the property is not a string.
        /// </remarks>
        /// <exception cref="ArgumentNullException">An argument is null.</exception>
        /// <exception cref="ArgumentException">The type has no such property.</exception>
        /// <exception cref="ArgumentOutOfRangeException">The declared column type is not a known DataType value.</exception>
        public static Type GetColumnType(Type type, string propertyName)
        {
            PropertyInfo pi = FindProperty(type, propertyName);
            ColumnAttribute ca = (ColumnAttribute)Attribute.GetCustomAttribute(pi, typeof(ColumnAttribute), false);
            return ca == null ? pi.PropertyType : SwitchType(ca.ColumnType);
        }

        // Looks up an instance property, failing with a descriptive message instead of an empty one.
        static PropertyInfo FindProperty(Type type, string propertyName)
        {
            if (type == null)
                throw new ArgumentNullException("type");
            if (propertyName == null)
                throw new ArgumentNullException("propertyName");
            PropertyInfo pi = type.GetProperty(propertyName, PropertyFlags);
            if (pi == null)
                throw new ArgumentException("Type '" + type.FullName + "' has no property named '" + propertyName + "'.", "propertyName");
            return pi;
        }

        // Maps a declared DataType to its CLR type; undefined enum values are rejected instead of yielding null.
        static Type SwitchType(DataType dtype)
        {
            switch (dtype)
            {
                case DataType.String:
                    return typeof(String);
                case DataType.Int:
                    return typeof(Int32);
                case DataType.DateTime:
                    return typeof(DateTime);
                case DataType.Float:
                    return typeof(float);
                case DataType.Double:
                    return typeof(double);
                default:
                    throw new ArgumentOutOfRangeException("dtype", dtype, "Unsupported column type.");
            }
        }
    }
}

自定义TableAttribute

using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;

namespace SimpleLinq2Sql
{
    /// <summary>
    /// Marks an entity class with the name of the database table it maps to.
    /// </summary>
    [AttributeUsage(AttributeTargets.Class)]
    internal class TableAttribute : Attribute
    {
        /// <param name="tableName">Table name used when generating SQL.</param>
        public TableAttribute(string tableName)
        {
            TableName = tableName;
        }

        /// <summary>The mapped table name, exactly as passed to the constructor.</summary>
        public string TableName { get; private set; }
    }
}

自定义ColumnAttribute

using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;

namespace SimpleLinq2Sql
{
    /// <summary>
    /// Maps an entity property to a database column name and, optionally, a declared column type.
    /// </summary>
    [AttributeUsage(AttributeTargets.Property)]
    internal class ColumnAttribute : Attribute
    {
        /// <param name="columnName">Column name used when generating SQL.</param>
        public ColumnAttribute(string columnName)
        {
            ColumnName = columnName;
            // Default applied before any named argument (e.g. ColumnType = DataType.Int) overrides it.
            ColumnType = DataType.String;
        }

        /// <summary>The mapped column name, exactly as passed to the constructor.</summary>
        public string ColumnName { get; private set; }

        /// <summary>
        /// Declared column type; <c>DataType.String</c> unless set as a named argument,
        /// e.g. <c>[Column("ID", ColumnType = DataType.Int)]</c>.
        /// </summary>
        public DataType ColumnType { get; set; }
    }
}

数据类型枚举

using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;

namespace SimpleLinq2Sql
{
    /// <summary>
    /// Column types that <c>ColumnAttribute.ColumnType</c> can declare; <c>MapHelper</c>
    /// translates each member to a CLR type.
    /// </summary>
    public enum DataType
    {
        /// <summary>Integer column (CLR <c>Int32</c>).</summary>
        Int = 0,

        /// <summary>Text column (CLR <c>String</c>).</summary>
        String = 1,

        /// <summary>Single-precision column (CLR <c>float</c>).</summary>
        Float = 2,

        /// <summary>Double-precision column (CLR <c>double</c>).</summary>
        Double = 3,

        /// <summary>Date/time column (CLR <c>DateTime</c>).</summary>
        DateTime = 4
    }
}

将 DataTable 转换为对应实体类的帮助类:

using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Data;
using System.Reflection;

namespace SimpleLinq2Sql
{
    /// <summary>
    /// Materializes DataTable rows into objects. Types marked with <c>[Table]</c> are created with
    /// their public parameterless constructor and filled property by property from the mapped
    /// columns. Any other type (such as an anonymous projection) is created through its widest
    /// public constructor, with each argument read from the column named after that parameter.
    /// </summary>
    internal static class Table2Entity
    {
        // Instance properties only: static properties are not per-row state and must not be filled from a row.
        const BindingFlags InstanceProperties = BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance;

        // Converts one row, choosing the strategy by the presence of [Table] on the target type.
        static object ConvertFromDataRow(DataRow dr, Type type)
        {
            return type.IsDefined(typeof(TableAttribute), false)
                ? ConvertMappedEntity(dr, type)
                : ConvertByConstructor(dr, type);
        }

        // Creates a [Table] entity and assigns every instance property from the column MapHelper maps it to.
        static object ConvertMappedEntity(DataRow dr, Type type)
        {
            object entity = Activator.CreateInstance(type);
            foreach (PropertyInfo p in type.GetProperties(InstanceProperties))
            {
                string column = MapHelper.GetColumnName(type, p.Name);
                if (!dr.Table.Columns.Contains(column))
                    throw new InvalidOperationException("Result has no column '" + column + "' for property " + type.Name + "." + p.Name + ".");
                p.SetValue(entity, ConvertValue(dr[column], p.PropertyType), null);
            }
            return entity;
        }

        // Creates an object through its widest public constructor, reading each argument from the column
        // named after the constructor parameter. Matching by name (anonymous-type constructor parameters are
        // named after the properties) avoids depending on Type.GetProperties order, which is unspecified.
        static object ConvertByConstructor(DataRow dr, Type type)
        {
            ConstructorInfo ctor = null;
            foreach (ConstructorInfo candidate in type.GetConstructors())
            {
                if (ctor == null || candidate.GetParameters().Length > ctor.GetParameters().Length)
                    ctor = candidate;
            }
            if (ctor == null)
                throw new InvalidOperationException("Type '" + type.FullName + "' has no public constructor.");

            ParameterInfo[] parameters = ctor.GetParameters();
            object[] args = new object[parameters.Length];
            for (int i = 0; i < parameters.Length; i++)
            {
                string column = parameters[i].Name;
                if (!dr.Table.Columns.Contains(column))
                    throw new InvalidOperationException("Result has no column '" + column + "' for constructor parameter of " + type.Name + ".");
                args[i] = ConvertValue(dr[column], parameters[i].ParameterType);
            }
            return ctor.Invoke(args);
        }

        // Converts a cell to targetType. DBNull becomes null for reference and Nullable<T> targets,
        // Nullable<T> converts via its underlying type, and enums are built from their numeric value.
        static object ConvertValue(object raw, Type targetType)
        {
            Type underlying = Nullable.GetUnderlyingType(targetType);
            if (raw == null || raw == DBNull.Value)
            {
                if (underlying != null || !targetType.IsValueType)
                    return null;
                throw new InvalidCastException("Cannot assign a database NULL to non-nullable type '" + targetType.FullName + "'.");
            }

            Type conversionType = underlying ?? targetType;
            if (conversionType.IsInstanceOfType(raw))
                return raw;
            if (conversionType.IsEnum)
                return Enum.ToObject(conversionType, raw);
            return Convert.ChangeType(raw, conversionType, System.Globalization.CultureInfo.InvariantCulture);
        }

        /// <summary>
        /// Converts every row of <paramref name="dt"/> and returns them as a <c>List</c> of
        /// <paramref name="type"/>, typed as object.
        /// </summary>
        /// <exception cref="ArgumentNullException">An argument is null.</exception>
        /// <exception cref="InvalidOperationException">A required column is missing from the result.</exception>
        public static object ConvertFromTable(DataTable dt, Type type)
        {
            if (dt == null)
                throw new ArgumentNullException("dt");
            if (type == null)
                throw new ArgumentNullException("type");

            // List<T> implements the non-generic IList, so rows can be added without reflection Invoke.
            System.Collections.IList list = (System.Collections.IList)Activator.CreateInstance(typeof(List<>).MakeGenericType(type));
            foreach (DataRow dr in dt.Rows)
                list.Add(ConvertFromDataRow(dr, type));
            return list;
        }
    }
}

自定义的实体类:

using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;

namespace SimpleLinq2Sql
{
    /// <summary>
    /// Entity mapped to the <c>Student</c> table. Every property names its column through
    /// <c>[Column]</c>; the integer columns also declare <c>ColumnType = DataType.Int</c>.
    /// </summary>
    [Table("Student")]
    public class Student
    {
        [Column("ID", ColumnType = DataType.Int)]
        public int ID { get; set; }

        /// <summary>Stored in column <c>StuName</c> (property and column names differ).</summary>
        [Column("StuName")]
        public string Name { get; set; }

        [Column("Address")]
        public string Address { get; set; }

        [Column("Sex", ColumnType = DataType.Int)]
        public int Sex { get; set; }

        [Column("CollegeID", ColumnType = DataType.Int)]
        public int CollegeID { get; set; }
    }
}

Program 执行入口:

using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;

namespace SimpleLinq2Sql
{
    /// <summary>
    /// Console demo: composes a query over the Student table, prints its generated SQL,
    /// then prints one comma-separated line per result row.
    /// </summary>
    class Program
    {
        static void Main(string[] args)
        {
            var students = new CustomTable<Student>();

            // Filter by address, project into an anonymous type (renaming Name/Address), then filter the projection.
            var fromChina = students.Where(r => r.Address == "china");
            var projected = fromChina.Select(r => new { NewName = r.Name, Country = r.Address, r.CollegeID, r.Sex });
            var query = projected.Where(r => r.Sex == 1);

            Console.WriteLine(query.ToString());

            foreach (var row in query)
                Console.WriteLine(string.Format("{0},{1},{2},{3}", row.NewName, row.Country, row.CollegeID, row.Sex));

            // Wait for console input before exiting.
            Console.Read();
        }
    }
}


 

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值