C#手把手教你写一个自己的ORM(一)

C#手把手教你写一个自己的ORM(一)

前言:

  网上的ORM框架很多,比如Freesql、sqlsugar、EF等等,在一个新项目中,我们可以直接引入使用;但是有个问题,如果你接手的是一个老项目,实体不能动,代码不能给人家搞乱,那你该怎么办呢?
  这个时候,我们就需要自己了解一下ORM的原理,然后自己写一个简易的,符合自己公司框架逻辑的ORM进行使用了。
  下面我会从浅到深,讲解一下ORM的基本实现方法。

ORM框架是什么

  对象关系映射(Object Relational Mapping),目前数据库是关系型数据库 , ORM 主要是把数据库中的关系数据映射成为程序中的对象。ORM提供了实现持久化层的另一种模式,它采用映射元数据来描述对象关系的映射,使得ORM中间件能在任何一个应用的业务逻辑层和数据库层之间充当桥梁。Java典型的ORM中间件有:Hibernate,Mybatis等。 这样能够让程序员更多的关注业务编程,而不用浪费在SQL语句的编写上。

ORM的方法论基于三个核心原则:

· 简单:以最基本的形式建模数据。
· 传达性:数据库结构被任何人都能理解的语言文档化。
· 精确性:基于数据模型创建正确标准化了的结构。

实现方式

实现方式分两种,看自己的需求。
1、基础版本:通过反射、属性来进行解析ORM表达式。(本文讲这个,进阶版本看下一文章)
2、进阶版本:跟市面上主流OMR一样,反射、属性+linq表达式解析,进行简易数据库操作。

实现代码

构造函数使用单例模式

		//构造函数私有化传入数据库链接字符串
        private ORM(string connectString)
        {
            _connectString = connectString;
        }
        
        /// <summary>
        /// 返回一个数据库连接实例
        /// </summary>
        /// <param name="connectString">连接字符串</param>
        /// <returns></returns>
        public static ORM CreateInstanse(string connectString) => new ORM(connectString.Replace("\"", ""));

Attributes属性定义,主要标明表名和主键名字

	// 表名
    public class TableNameAttribute : Attribute
    {
        public object _value { get; private set; }

        public TableNameAttribute(string tableName)
        {
            this._value = tableName;
        }
    }
    // 主键名
    public class PrimaryKeyAttribute : Attribute
    {
        public object _value { get; private set; }

        public PrimaryKeyAttribute(string primaryKey)
        {
            this._value = primaryKey;
        }
    }
	// 使用方式(例)
    /// <summary>
    /// 学生类
    /// </summary>
    [TableName("MapRecord")]
    [PrimaryKey("ID")]
    public class Student
    {
        // 身份ID
        public int Id { get; set; }
        // 名字
        public string Name { get; set; }
        // 长度
        public long Length { get; set; }
    }

对外方法

   #region 对外方法

        /// <summary>
        /// 使用ado.net的方法查询list
        /// </summary>
        /// <typeparam name="T"></typeparam>
        /// <param name="sql"></param>
        /// <param name="connectionString"></param>
        /// <returns></returns>
        public List<T> QueryList<T>(string sql) where T : new() => GetQueryList<T>(sql);

        /// <summary>
        /// 使用ado.net的方法分页查询list
        /// </summary>
        /// <typeparam name="T"></typeparam>
        /// <param name="sql"></param>
        /// <param name="connectionString"></param>
        /// <returns></returns>
        public List<T> QueryListPage<T>(string sql, Expression<Func<T, object>> orderEx, long pageIndex = 1, long pageSize = 10) where T : new()
        {
            var result = new List<T>();
            var orderObj = orderEx.GetExpressionPropertyNames();
            var pageSql = $@"select * from (select ROW_NUMBER() OVER(ORDER BY {string.Join(", ", orderObj)} desc) rn, t.* from ({sql}) t) t1 where t1.rn > {(pageIndex - 1) * pageSize} and t1.rn <= {pageIndex * pageSize}";

            result = GetQueryList<T>(pageSql);

            return result;
        }

        /// <summary>
        /// 返回该表行数
        /// </summary>
        /// <typeparam name="T"></typeparam>
        /// <returns></returns>
        public int Count(string sql)
        {
            sql = $@"select count(1) Count from ({sql}) t";

            var result = QueryFirst<Common>(sql);

            return result.Count;
        }

        /// <summary>
        /// 使用ado.net的方法分页查询list
        /// </summary>
        /// <typeparam name="T"></typeparam>
        /// <param name="sql"></param>
        /// <param name="connectionString"></param>
        /// <returns></returns>
        public static List<T> Delete<T>(string sql, string connectionString, Expression<Func<T, object>> orderEx, int pageIndex = 1, int pageSize = 10) where T : new()
        {
            var result = new List<T>();
            var orderObj = orderEx.GetExpressionPropertyNames();
            var pageSql = $@"select * from (select ROW_NUMBER() OVER(ORDER BY {string.Join(", ", orderObj)}) rn, t.* from ({sql}) t) t1 where t1.rn > {(pageIndex - 1) * pageSize} and t1.rn <= {pageIndex * pageSize}";

            using (SqlConnection connection = new SqlConnection(connectionString))
            {
                DataTable dt = new DataTable();
                try
                {
                    connection.Open();
                    using (SqlDataAdapter command = new SqlDataAdapter(pageSql, connection))
                    {
                        command.Fill(dt);
                        result = dt.ToList<T>();
                    }
                }
                catch (System.Data.SqlClient.SqlException ex)
                {
                    Logger.Error(ex);
                }
            }

            return result;
        }

        /// <summary>
        /// 批量插入 返回受影响的行数
        /// </summary>
        /// <typeparam name="T"></typeparam>
        /// <param name="list"></param>
        /// <returns></returns>
        public int InsertList<T>(List<T> list) where T : new()
        {
            var sql = $@"";
            var properties = typeof(T).GetProperties().ToList();

            var key = typeof(T).CustomAttributes.FirstOrDefault(t => t.AttributeType.Name.ToLower().Contains("key")).ConstructorArguments.FirstOrDefault().Value.ToString();
            var table = typeof(T).CustomAttributes.FirstOrDefault(t => t.AttributeType.Name.ToLower().Contains("table")).ConstructorArguments.FirstOrDefault().Value.ToString();

            foreach (var item in list)
            {
                var keyProperty = properties.FirstOrDefault(t => t.Name == key);
                var keyValue = keyProperty.GetValue(item);
                if (keyValue == null || keyValue.ToString().IsNullOrEmpty() || keyValue.ToString() == "0")
                {
                    properties.Remove(keyProperty);
                }

                var columns = properties.Select(t => t.Name).ToList();

                var values = properties.Select(t => t.GetValue(item) ?? "NULL").ToList();


                sql += $@"insert into {table}({string.Join(", ", columns)}) values('{string.Join("', '", values)}'); 
                            ".Replace("'NULL'", "NULL");
            }

            return ExecuteNonQuery(sql);
        }

        /// <summary>
        /// 批量插入 返回受影响的行数
        /// </summary>
        /// <typeparam name="T"></typeparam>
        /// <param name="item"></param>
        /// <returns></returns>
        public int Insert<T>(T item) where T : new()
        {
            var sql = $@"";
            var properties = typeof(T).GetProperties().ToList();

            var key = typeof(T).CustomAttributes.FirstOrDefault(t => t.AttributeType.Name.ToLower().Contains("key")).ConstructorArguments.FirstOrDefault().Value.ToString();
            var table = typeof(T).CustomAttributes.FirstOrDefault(t => t.AttributeType.Name.ToLower().Contains("table")).ConstructorArguments.FirstOrDefault().Value.ToString();

            var keyProperty = properties.FirstOrDefault(t => t.Name == key);
            var keyValue = keyProperty.GetValue(item);
            if (keyValue == null || keyValue.ToString().IsNullOrEmpty() || keyValue.ToString() == "0")
            {
                properties.Remove(keyProperty);
            }

            var columns = properties.Select(t => t.Name).ToList();

            var values = properties.Select(t => t.GetValue(item) ?? "NULL").ToList();


            sql += $@"insert into {table}({string.Join(", ", columns)}) values('{string.Join("', '", values)}'); 
                            ".Replace("'NULL'", "NULL");

            return ExecuteNonQuery(sql);
        }

        /// <summary>
        /// 更新单个实体 不为空的字段则更新
        /// </summary>
        /// <typeparam name="T">更新表</typeparam>
        /// <param name="item"></param>
        /// <returns></returns>
        public int Update<T>(T item)
        {
            var updates = new List<string>();
            var properties = item.GetType().GetProperties().ToList();
            var tableProperties = typeof(T).GetProperties().ToList();
            var key = typeof(T).CustomAttributes.FirstOrDefault(t => t.AttributeType.Name.ToLower().Contains("key")).ConstructorArguments.FirstOrDefault().Value.ToString();
            var table = typeof(T).CustomAttributes.FirstOrDefault(t => t.AttributeType.Name.ToLower().Contains("table")).ConstructorArguments.FirstOrDefault().Value.ToString();

            var keyProperty = properties.FirstOrDefault(t => t.Name == key);
            var tag = properties.Remove(keyProperty);

            foreach (var property in properties)
            {
                var value = property.GetValue(item);
                if (value != null && value.ToString().IsNotNull() && tableProperties.Any(t => t.Name == property.Name))
                {
                    updates.Add($"{property.Name} = '{value}'");
                }
            }

            var sql = $@"update {table} set {string.Join(", ", updates)} where {key} = '{keyProperty.GetValue(item)}';";

            return ExecuteNonQuery(sql);
        }

        /// <summary>
        /// 更新单个实体 不为空的字段则更新
        /// </summary>
        /// <typeparam name="T">更新表</typeparam>
        /// <param name="list"></param>
        /// <returns></returns>
        public int UpdateList<T>(List<T> list)
        {
            var properties = typeof(T).GetProperties().ToList();
            var tableProperties = typeof(T).GetProperties().ToList();
            var key = typeof(T).CustomAttributes.FirstOrDefault(t => t.AttributeType.Name.ToLower().Contains("key")).ConstructorArguments.FirstOrDefault().Value.ToString();
            var table = typeof(T).CustomAttributes.FirstOrDefault(t => t.AttributeType.Name.ToLower().Contains("table")).ConstructorArguments.FirstOrDefault().Value.ToString();

            var keyProperty = properties.FirstOrDefault(t => t.Name == key);
            var tag = properties.Remove(keyProperty);

            var sql = string.Empty;

            foreach (var item in list)
            {
                var updates = new List<string>();
                foreach (var property in properties)
                {
                    var value = property.GetValue(list);
                    if (value != null && value.ToString().IsNotNull() && tableProperties.Any(t => t.Name == property.Name))
                    {
                        updates.Add($"{property.Name} = '{value}'");
                    }
                }
                sql += $@"update {table} set {string.Join(", ", updates)} where {key} = '{keyProperty.GetValue(list)}';";
            }

            return ExecuteNonQuery(sql);
        }

        #region 转Linq 下一篇文章讲解这个

        /// <summary>
        /// 转到linq查询 返回一个新的查询实例
        /// </summary>
        /// <typeparam name="T1"></typeparam>
        /// <returns></returns>
        public ISelect<T1> Select<T1>() where T1: new () => new SelectProvider<T1>(this);

        #endregion

        #endregion

对内方法

#region 私有方法

        /// <summary>
        /// 执行sql 返回受影响的行数
        /// </summary>
        /// <param name="sql"></param>
        /// <returns></returns>
        private int ExecuteNonQuery(string sql)
        {
            var effectRows = 0;
            using (SqlConnection connection = new SqlConnection(_connectString))
            {
                connection.Open();
                SqlTransaction transaction = connection.BeginTransaction();
                using (SqlCommand cmd = connection.CreateCommand())
                {
                    cmd.Connection = connection;
                    cmd.Transaction = transaction;

                    if (sql.Length > 0)
                    {
                        cmd.CommandText = sql;
                        cmd.CommandType = CommandType.Text;

                        try
                        {
                            effectRows = cmd.ExecuteNonQuery();
                            transaction.Commit();
                        }
                        catch (SqlException ex)
                        {
                            transaction.Rollback();
                            Logger.Error(ex);
                        }
                    }
                }

            }
            return effectRows;
        }

        /// <summary>
        /// 查询 返回列表
        /// </summary>
        /// <typeparam name="T"></typeparam>
        /// <param name="sql"></param>
        /// <returns></returns>
        private List<T> GetQueryList<T>(string sql) where T : new()
        {
            List<T> result = new List<T>();

            using (SqlConnection connection = new SqlConnection(_connectString))
            {
                DataTable dt = new DataTable();
                try
                {
                    connection.Open();
                    using (SqlDataAdapter command = new SqlDataAdapter(sql, connection))
                    {
                        command.Fill(dt);
                        result = dt.ToList<T>();
                    }
                }
                catch (System.Data.SqlClient.SqlException ex)
                {
                    Logger.Error(ex);
                }
            }

            return result;
        }

        /// <summary>
        /// 查询 返回第一列
        /// </summary>
        /// <typeparam name="T"></typeparam>
        /// <param name="sql"></param>
        /// <returns></returns>
        private T QueryFirst<T>(string sql) where T : new()
        {
            T result = new T();

            using (SqlConnection connection = new SqlConnection(_connectString))
            {
                DataTable dt = new DataTable();
                try
                {
                    connection.Open();
                    using (SqlDataAdapter command = new SqlDataAdapter(sql, connection))
                    {
                        command.Fill(dt);
                        result = dt.ToList<T>().FirstOrDefault();
                    }
                }
                catch (System.Data.SqlClient.SqlException ex)
                {
                    Logger.Error(ex);
                }
            }

            return result;
        }
		
		// 反射需要用到此实体
        private class Common
        {
            public int Count { get; set; }
        }
        

        /// <summary>
        /// 获取表达式属性列表
        /// </summary>
        /// <typeparam name="T"></typeparam>
        /// <param name="exp"></param>
        /// <returns></returns>
        public static List<string> GetExpressionPropertyNames<T>(this Expression<Func<T, object>> exp)
        {
            var result = SelectMembers(exp.Body).Select(t => t.Name).ToList();

            return result;
        }
        private static List<MemberInfo> SelectMembers(Expression properties)
        {
            if (properties is NewExpression newExp && newExp.Members != null)
            {
                return newExp.Members.ToList();
            }
            else
            {
                if (properties is NewArrayExpression newArr)
                {
                    List<MemberInfo> newArrMembers = new List<MemberInfo>();
                    foreach (var newArrExp in newArr.Expressions)
                        newArrMembers.AddRange(SelectMembers(newArrExp));
                    return newArrMembers.Distinct().ToList();
                }
                return null;
            }
        }
        
        #endregion

DataTable转list通用方法,小伙伴可以收藏哦


        /// <summary>
        /// datatable转list通用方法
        /// </summary>
        /// <typeparam name="T"></typeparam>
        /// <param name="table"></param>
        /// <returns></returns>
        public static List<T> ToList<T>(this DataTable table) where T : new()
        {
            var result = new List<T>();

            foreach (DataRow row in table.Rows)
            {
                T obj = new T();
                foreach (DataColumn column in table.Columns)
                {
                    var property = typeof(T).GetProperty(column.ColumnName);
                    if (property != null && row[column].ToString().IsNotNull())
                    {
                        try
                        {
                            property.SetValue(obj, row[column], null);
                        }
                        catch (Exception)
                        {
                            property.SetValue(obj, Convert.ChangeType(row[column], property.PropertyType), null);
                        }
                    }
                }
                result.Add(obj);
            }

            return result;
        }

使用方法

public class TestClass
{
	static readonly ORM db = ORM.CreateInstanse(Watch_Com.JsonHelper.AppSetting["ConnectionStrings:SoftDb"]); // 此处为连接字符串,需要源码的可以去我的json文章篇里有
	public void testFunc()
	{
		//插入数据
        var stu = new Student()
        {
            Id=1,
            Name="张三",
            Length=18
        };
        db.Insert(stu);
        
		//更新数据
        stu.Length++;
        db.Update(stu);
        
        // 查询列表
         var sql = $@"select * from Student";
         var stu1 = db.QueryList<Student>(sql);

         // 分页查询
         var sql = $@"select * from Student";
         var stu2 = db.QueryListPage<Student>(sql, t=> t.Id, 1, 10); //1为页码 10位页容量
	}
}

总结

  基础版的ORM多使用反射,进行属性值的判断,进行数据库操作,但是底层还是ado.net,具体需要实现什么样的功能,还需要按照自己的实际项目确定。
  具体的linq解析,表达式树解析版本的ORM框架搭建,可以看我的下一篇文章,应该会在近两天更新。

最后感谢大家的阅读,如有疑问,欢迎一起探讨!

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值