C#手把手教你写一个自己的ORM(一)
前言:
网上的ORM框架很多,比如FreeSql、SqlSugar、EF(Entity Framework)等等,在新项目中可以直接引入使用;但如果你接手的是一个老项目,实体类不能改动,原有代码也不能被改乱,那该怎么办呢?
这个时候,我们就需要自己了解一下ORM的原理,然后自己写一个简易的,符合自己公司框架逻辑的ORM进行使用了。
下面我会从浅到深,讲解一下ORM的基本实现方法。
ORM框架是什么
对象关系映射(Object Relational Mapping,ORM)。目前主流数据库是关系型数据库,ORM 主要是把数据库中的关系数据映射成为程序中的对象。ORM提供了实现持久化层的另一种模式,它采用映射元数据来描述对象与关系之间的映射,使得ORM中间件能在任何一个应用的业务逻辑层和数据库层之间充当桥梁。Java中典型的ORM中间件有Hibernate、MyBatis等,.NET中则有EF、FreeSql、SqlSugar等。这样能够让程序员更多地关注业务编程,而不用把精力浪费在SQL语句的编写上。
ORM的方法论基于三个核心原则:
· 简单:以最基本的形式建模数据。
· 传达性:用任何人都能理解的语言来记录数据库结构。
· 精确性:基于数据模型创建正确且规范化的结构。
实现方式
实现方式分两种,看自己的需求。
1、基础版本:通过反射和特性(Attribute)解析实体,拼接SQL进行数据库操作。(本文讲这个,进阶版本见下一篇文章)
2、进阶版本:跟市面上主流ORM一样,使用反射、特性+LINQ表达式树解析,进行简易的数据库操作。
实现代码
构造函数私有化,通过静态工厂方法CreateInstanse创建实例(注意:每次调用都会返回一个新实例,并不是单例;需要全局共用一个实例时,由调用方自行缓存,例如保存在static readonly字段中)
/// <summary>
/// Private constructor: instances are created through the static factory
/// <c>CreateInstanse</c> rather than with <c>new</c>.
/// </summary>
/// <param name="connectString">ADO.NET connection string used for the connections this instance opens.</param>
private ORM(string connectString) => _connectString = connectString;
/// <summary>
/// Creates a new <see cref="ORM"/> instance bound to the given connection string.
/// </summary>
/// <remarks>
/// Every call returns a NEW instance (this is a factory, not a singleton); callers that want
/// one shared instance must cache it themselves, e.g. in a <c>static readonly</c> field.
/// All double-quote characters are stripped first — presumably because connection strings
/// read from JSON configuration can arrive still wrapped in quotes (confirm with callers).
/// </remarks>
/// <param name="connectString">The connection string.</param>
/// <returns>A new ORM instance.</returns>
/// <exception cref="ArgumentNullException"><paramref name="connectString"/> is <c>null</c>.</exception>
public static ORM CreateInstanse(string connectString)
{
    if (connectString == null)
    {
        throw new ArgumentNullException(nameof(connectString));
    }
    return new ORM(connectString.Replace("\"", ""));
}
特性(Attribute)定义,主要用于标明表名和主键名
/// <summary>
/// Tags an entity type with the name of the database table it is stored in.
/// </summary>
public class TableNameAttribute : Attribute
{
    /// <summary>The table name that was passed to the constructor.</summary>
    public object _value { get; private set; }

    /// <summary>Creates the attribute.</summary>
    /// <param name="tableName">Name of the mapped table.</param>
    public TableNameAttribute(string tableName) => _value = tableName;
}
/// <summary>
/// Tags an entity type with the name of its primary-key column.
/// </summary>
public class PrimaryKeyAttribute : Attribute
{
    /// <summary>The primary-key name that was passed to the constructor.</summary>
    public object _value { get; private set; }

    /// <summary>Creates the attribute.</summary>
    /// <param name="primaryKey">Name of the primary-key column / property.</param>
    public PrimaryKeyAttribute(string primaryKey) => _value = primaryKey;
}
// Usage example
/// <summary>
/// Example entity (a student), mapped to the <c>MapRecord</c> table.
/// </summary>
/// <remarks>
/// The primary-key name is spelled exactly like the C# property (<c>Id</c>) so that
/// reflection-based lookups of the key property by name find it.
/// </remarks>
[TableName("MapRecord")]
[PrimaryKey("Id")]
public class Student
{
    /// <summary>Identity / primary key.</summary>
    public int Id { get; set; }

    /// <summary>Name.</summary>
    public string Name { get; set; }

    /// <summary>Length.</summary>
    public long Length { get; set; }
}
对外方法
#region 对外方法
/// <summary>
/// Runs a raw SQL query through ADO.NET and maps every returned row to a <typeparamref name="T"/>.
/// </summary>
/// <typeparam name="T">Row type; result columns are copied into its properties by name.</typeparam>
/// <param name="sql">The SELECT statement to execute.</param>
/// <returns>The mapped rows; an empty list when a SqlException occurred (it is logged).</returns>
public List<T> QueryList<T>(string sql) where T : new()
{
    return GetQueryList<T>(sql);
}
/// <summary>
/// Runs a raw SQL query through ADO.NET and returns one page of its rows.
/// </summary>
/// <remarks>
/// <paramref name="sql"/> is embedded as a derived table (<c>from (sql) t</c>) and numbered with
/// <c>ROW_NUMBER() OVER(ORDER BY ...)</c>, so it must be valid in that position.
/// </remarks>
/// <typeparam name="T">Row type; result columns are copied into its properties by name.</typeparam>
/// <param name="sql">The query to page over.</param>
/// <param name="orderEx">
/// Selects the ordering column(s): <c>t =&gt; t.Id</c>, <c>t =&gt; new { t.Id, t.Name }</c>
/// or <c>t =&gt; new object[] { t.Id, t.Name }</c>.
/// </param>
/// <param name="pageIndex">1-based page number.</param>
/// <param name="pageSize">Number of rows per page.</param>
/// <param name="descending">
/// <c>true</c> (default) sorts every ordering column descending; <c>false</c> sorts ascending.
/// </param>
/// <returns>The requested page; an empty list when a SqlException occurred (it is logged).</returns>
/// <exception cref="ArgumentNullException"><paramref name="orderEx"/> is <c>null</c>.</exception>
public List<T> QueryListPage<T>(string sql, Expression<Func<T, object>> orderEx, long pageIndex = 1, long pageSize = 10, bool descending = true) where T : new()
{
    if (orderEx == null)
    {
        throw new ArgumentNullException(nameof(orderEx));
    }
    var direction = descending ? "desc" : "asc";
    // Apply the direction to EVERY column: "ORDER BY a, b desc" would sort only b descending.
    var orderBy = string.Join(", ", orderEx.GetExpressionPropertyNames().Select(column => $"{column} {direction}"));
    var lowerBound = (pageIndex - 1) * pageSize;
    var upperBound = pageIndex * pageSize;
    var pageSql = $@"select * from (select ROW_NUMBER() OVER(ORDER BY {orderBy}) rn, t.* from ({sql}) t) t1 where t1.rn > {lowerBound} and t1.rn <= {upperBound}";
    return GetQueryList<T>(pageSql);
}
/// <summary>
/// Returns the number of rows produced by <paramref name="sql"/>.
/// </summary>
/// <param name="sql">Any SELECT statement; it is wrapped as <c>select count(1) Count from (sql) t</c>.</param>
/// <returns>
/// The row count; 0 when the count query fails (QueryFirst logs the SqlException and returns a
/// default-constructed row) or yields no row at all.
/// </returns>
public int Count(string sql)
{
    var countSql = $@"select count(1) Count from ({sql}) t";
    // QueryFirst returns FirstOrDefault() of the rows, i.e. null when nothing comes back.
    return QueryFirst<Common>(countSql)?.Count ?? 0;
}
/// <summary>
/// Static, connection-string based paged query. Despite its name this method does NOT delete
/// anything: it runs a paged SELECT and maps the rows to <typeparamref name="T"/>.
/// </summary>
/// <remarks>
/// Rows are numbered with <c>ROW_NUMBER() OVER(ORDER BY ...)</c> over the selected columns with
/// no explicit direction (i.e. ascending). A SqlException is logged and yields an empty list.
/// </remarks>
/// <typeparam name="T">Row type; result columns are copied into its properties by name.</typeparam>
/// <param name="sql">Query to page over; embedded as a derived table.</param>
/// <param name="connectionString">Connection string of the database to query.</param>
/// <param name="orderEx">Selects the ordering column(s).</param>
/// <param name="pageIndex">1-based page number.</param>
/// <param name="pageSize">Number of rows per page.</param>
/// <returns>The requested page.</returns>
public static List<T> Delete<T>(string sql, string connectionString, Expression<Func<T, object>> orderEx, int pageIndex = 1, int pageSize = 10) where T : new()
{
    var rows = new List<T>();
    var orderColumns = string.Join(", ", orderEx.GetExpressionPropertyNames());
    var skipped = (pageIndex - 1) * pageSize;
    var last = pageIndex * pageSize;
    var pageSql = $@"select * from (select ROW_NUMBER() OVER(ORDER BY {orderColumns}) rn, t.* from ({sql}) t) t1 where t1.rn > {skipped} and t1.rn <= {last}";
    using (var connection = new SqlConnection(connectionString))
    {
        var buffer = new DataTable();
        try
        {
            connection.Open();
            using (var adapter = new SqlDataAdapter(pageSql, connection))
            {
                adapter.Fill(buffer);
                rows = buffer.ToList<T>();
            }
        }
        catch (System.Data.SqlClient.SqlException ex)
        {
            Logger.Error(ex);
        }
    }
    return rows;
}
/// <summary>
/// Inserts every entity in <paramref name="list"/> (one INSERT per item, sent as a single batch)
/// and returns the number of affected rows.
/// </summary>
/// <remarks>
/// The table and primary-key names are the first constructor argument of the first attribute on
/// <typeparamref name="T"/> whose type name contains "table" / "key" (case-insensitive).
/// When an item's key value is null, empty or "0", the key column is left out of THAT item's
/// INSERT (presumably so an identity column can generate it — confirm against the schema).
/// Values are written as quoted SQL literals with single quotes doubled; <c>null</c> becomes NULL.
/// </remarks>
/// <typeparam name="T">Entity type.</typeparam>
/// <param name="list">Entities to insert; an empty list inserts nothing.</param>
/// <returns>Rows affected, as reported by <c>ExecuteNonQuery</c>.</returns>
/// <exception cref="ArgumentNullException"><paramref name="list"/> is <c>null</c>.</exception>
/// <exception cref="InvalidOperationException">The table/key attribute or the key property cannot be found.</exception>
public int InsertList<T>(List<T> list) where T : new()
{
    if (list == null)
    {
        throw new ArgumentNullException(nameof(list));
    }
    if (list.Count == 0)
    {
        return 0;
    }
    var type = typeof(T);
    var key = GetAttributeArgument("key");
    var table = GetAttributeArgument("table");
    var allProperties = type.GetProperties().ToList();
    // Prefer an exact match, then a case-insensitive one, so a key declared as "ID" still finds "Id".
    var keyProperty = allProperties.FirstOrDefault(p => p.Name == key)
        ?? allProperties.FirstOrDefault(p => string.Equals(p.Name, key, StringComparison.OrdinalIgnoreCase))
        ?? throw new InvalidOperationException($"{type.Name} has no property matching primary key '{key}'.");

    var statements = new List<string>(list.Count);
    foreach (var item in list)
    {
        var keyValue = keyProperty.GetValue(item);
        var omitKey = keyValue == null || keyValue.ToString().IsNullOrEmpty() || keyValue.ToString() == "0";
        // Decided per item without mutating allProperties, so an item with an auto-generated key
        // does not drop the key column for the items after it.
        var properties = omitKey ? allProperties.Where(p => p != keyProperty).ToList() : allProperties;
        var columns = string.Join(", ", properties.Select(p => p.Name));
        var values = string.Join(", ", properties.Select(p => ToSqlLiteral(p.GetValue(item))));
        statements.Add($"insert into {table}({columns}) values({values});");
    }
    return ExecuteNonQuery(string.Join(Environment.NewLine, statements));

    // First constructor argument of the first attribute whose type name contains the fragment.
    string GetAttributeArgument(string nameFragment)
    {
        var attribute = type.CustomAttributes.FirstOrDefault(a => a.AttributeType.Name.IndexOf(nameFragment, StringComparison.OrdinalIgnoreCase) >= 0);
        var argument = attribute?.ConstructorArguments.FirstOrDefault().Value;
        return argument?.ToString()
            ?? throw new InvalidOperationException($"{type.Name} has no attribute containing '{nameFragment}' with a constructor argument.");
    }

    // Renders a value as a T-SQL literal. Embedded single quotes are doubled so values such as
    // O'Brien neither break nor inject into the statement; only a real null becomes NULL.
    string ToSqlLiteral(object value)
    {
        if (value == null)
        {
            return "NULL";
        }
        string text;
        if (value is DateTime dateTime)
        {
            // ISO 8601 with 'T' is read the same way whatever the session's DATEFORMAT/LANGUAGE.
            text = dateTime.ToString("yyyy-MM-ddTHH:mm:ss.fff", System.Globalization.CultureInfo.InvariantCulture);
        }
        else if (value is IFormattable formattable)
        {
            // Culture-independent numbers ("1.5", never "1,5").
            text = formattable.ToString(null, System.Globalization.CultureInfo.InvariantCulture);
        }
        else
        {
            text = value.ToString();
        }
        return "'" + text.Replace("'", "''") + "'";
    }
}
/// <summary>
/// Inserts a single entity and returns the number of affected rows.
/// </summary>
/// <remarks>
/// The table and primary-key names are the first constructor argument of the first attribute on
/// <typeparamref name="T"/> whose type name contains "table" / "key" (case-insensitive).
/// When the key value is null, empty or "0" the key column is left out of the INSERT
/// (presumably so an identity column can generate it — confirm against the schema).
/// Values are written as quoted SQL literals with single quotes doubled; <c>null</c> becomes NULL.
/// </remarks>
/// <typeparam name="T">Entity type.</typeparam>
/// <param name="item">Entity to insert.</param>
/// <returns>Rows affected, as reported by <c>ExecuteNonQuery</c>.</returns>
/// <exception cref="ArgumentNullException"><paramref name="item"/> is <c>null</c>.</exception>
/// <exception cref="InvalidOperationException">The table/key attribute or the key property cannot be found.</exception>
public int Insert<T>(T item) where T : new()
{
    if (item == null)
    {
        throw new ArgumentNullException(nameof(item));
    }
    var type = typeof(T);
    var key = GetAttributeArgument("key");
    var table = GetAttributeArgument("table");
    var properties = type.GetProperties().ToList();
    // Prefer an exact match, then a case-insensitive one, so a key declared as "ID" still finds "Id".
    var keyProperty = properties.FirstOrDefault(p => p.Name == key)
        ?? properties.FirstOrDefault(p => string.Equals(p.Name, key, StringComparison.OrdinalIgnoreCase))
        ?? throw new InvalidOperationException($"{type.Name} has no property matching primary key '{key}'.");
    var keyValue = keyProperty.GetValue(item);
    if (keyValue == null || keyValue.ToString().IsNullOrEmpty() || keyValue.ToString() == "0")
    {
        properties.Remove(keyProperty);
    }
    var columns = string.Join(", ", properties.Select(p => p.Name));
    var values = string.Join(", ", properties.Select(p => ToSqlLiteral(p.GetValue(item))));
    var sql = $"insert into {table}({columns}) values({values});";
    return ExecuteNonQuery(sql);

    // First constructor argument of the first attribute whose type name contains the fragment.
    string GetAttributeArgument(string nameFragment)
    {
        var attribute = type.CustomAttributes.FirstOrDefault(a => a.AttributeType.Name.IndexOf(nameFragment, StringComparison.OrdinalIgnoreCase) >= 0);
        var argument = attribute?.ConstructorArguments.FirstOrDefault().Value;
        return argument?.ToString()
            ?? throw new InvalidOperationException($"{type.Name} has no attribute containing '{nameFragment}' with a constructor argument.");
    }

    // Renders a value as a T-SQL literal. Embedded single quotes are doubled so values such as
    // O'Brien neither break nor inject into the statement; only a real null becomes NULL.
    string ToSqlLiteral(object value)
    {
        if (value == null)
        {
            return "NULL";
        }
        string text;
        if (value is DateTime dateTime)
        {
            // ISO 8601 with 'T' is read the same way whatever the session's DATEFORMAT/LANGUAGE.
            text = dateTime.ToString("yyyy-MM-ddTHH:mm:ss.fff", System.Globalization.CultureInfo.InvariantCulture);
        }
        else if (value is IFormattable formattable)
        {
            // Culture-independent numbers ("1.5", never "1,5").
            text = formattable.ToString(null, System.Globalization.CultureInfo.InvariantCulture);
        }
        else
        {
            text = value.ToString();
        }
        return "'" + text.Replace("'", "''") + "'";
    }
}
/// <summary>
/// Updates one entity by primary key, writing only the properties whose value is not null/empty.
/// </summary>
/// <remarks>
/// Properties are read from the runtime type of <paramref name="item"/>, but only those that also
/// exist on <typeparamref name="T"/> are written. The table and primary-key names are the first
/// constructor argument of the first attribute on <typeparamref name="T"/> whose type name
/// contains "table" / "key" (case-insensitive). Values are written as quoted SQL literals with
/// single quotes doubled.
/// </remarks>
/// <typeparam name="T">Entity type that determines the table.</typeparam>
/// <param name="item">Entity to update.</param>
/// <returns>Rows affected; 0 when there is no column to update.</returns>
/// <exception cref="ArgumentNullException"><paramref name="item"/> is <c>null</c>.</exception>
/// <exception cref="InvalidOperationException">The table/key attribute or the key property cannot be found.</exception>
public int Update<T>(T item)
{
    if (item == null)
    {
        throw new ArgumentNullException(nameof(item));
    }
    var type = typeof(T);
    var key = GetAttributeArgument("key");
    var table = GetAttributeArgument("table");
    var properties = item.GetType().GetProperties().ToList();
    var tablePropertyNames = new HashSet<string>(type.GetProperties().Select(p => p.Name));
    // Prefer an exact match, then a case-insensitive one, so a key declared as "ID" still finds "Id".
    var keyProperty = properties.FirstOrDefault(p => p.Name == key)
        ?? properties.FirstOrDefault(p => string.Equals(p.Name, key, StringComparison.OrdinalIgnoreCase))
        ?? throw new InvalidOperationException($"{type.Name} has no property matching primary key '{key}'.");
    properties.Remove(keyProperty);

    var updates = new List<string>();
    foreach (var property in properties)
    {
        var value = property.GetValue(item);
        if (value != null && value.ToString().IsNotNull() && tablePropertyNames.Contains(property.Name))
        {
            updates.Add($"{property.Name} = {ToSqlLiteral(value)}");
        }
    }
    if (updates.Count == 0)
    {
        // "update t set  where ..." is invalid SQL; there is simply nothing to write.
        return 0;
    }
    var sql = $@"update {table} set {string.Join(", ", updates)} where {key} = {ToSqlLiteral(keyProperty.GetValue(item))};";
    return ExecuteNonQuery(sql);

    // First constructor argument of the first attribute whose type name contains the fragment.
    string GetAttributeArgument(string nameFragment)
    {
        var attribute = type.CustomAttributes.FirstOrDefault(a => a.AttributeType.Name.IndexOf(nameFragment, StringComparison.OrdinalIgnoreCase) >= 0);
        var argument = attribute?.ConstructorArguments.FirstOrDefault().Value;
        return argument?.ToString()
            ?? throw new InvalidOperationException($"{type.Name} has no attribute containing '{nameFragment}' with a constructor argument.");
    }

    // Renders a value as a T-SQL literal. Embedded single quotes are doubled so values such as
    // O'Brien neither break nor inject into the statement; only a real null becomes NULL.
    string ToSqlLiteral(object value)
    {
        if (value == null)
        {
            return "NULL";
        }
        string text;
        if (value is DateTime dateTime)
        {
            // ISO 8601 with 'T' is read the same way whatever the session's DATEFORMAT/LANGUAGE.
            text = dateTime.ToString("yyyy-MM-ddTHH:mm:ss.fff", System.Globalization.CultureInfo.InvariantCulture);
        }
        else if (value is IFormattable formattable)
        {
            // Culture-independent numbers ("1.5", never "1,5").
            text = formattable.ToString(null, System.Globalization.CultureInfo.InvariantCulture);
        }
        else
        {
            text = value.ToString();
        }
        return "'" + text.Replace("'", "''") + "'";
    }
}
/// <summary>
/// Updates every entity in <paramref name="list"/> by primary key (one UPDATE per item, sent as a
/// single batch), writing only the properties whose value is not null/empty.
/// </summary>
/// <remarks>
/// The table and primary-key names are the first constructor argument of the first attribute on
/// <typeparamref name="T"/> whose type name contains "table" / "key" (case-insensitive).
/// Items with no column to update are skipped. Values are written as quoted SQL literals with
/// single quotes doubled.
/// </remarks>
/// <typeparam name="T">Entity type that determines the table.</typeparam>
/// <param name="list">Entities to update; an empty list updates nothing.</param>
/// <returns>Total rows affected, as reported by <c>ExecuteNonQuery</c>.</returns>
/// <exception cref="ArgumentNullException"><paramref name="list"/> is <c>null</c>.</exception>
/// <exception cref="InvalidOperationException">The table/key attribute or the key property cannot be found.</exception>
public int UpdateList<T>(List<T> list)
{
    if (list == null)
    {
        throw new ArgumentNullException(nameof(list));
    }
    if (list.Count == 0)
    {
        return 0;
    }
    var type = typeof(T);
    var key = GetAttributeArgument("key");
    var table = GetAttributeArgument("table");
    var properties = type.GetProperties().ToList();
    // Prefer an exact match, then a case-insensitive one, so a key declared as "ID" still finds "Id".
    var keyProperty = properties.FirstOrDefault(p => p.Name == key)
        ?? properties.FirstOrDefault(p => string.Equals(p.Name, key, StringComparison.OrdinalIgnoreCase))
        ?? throw new InvalidOperationException($"{type.Name} has no property matching primary key '{key}'.");
    properties.Remove(keyProperty);

    var statements = new List<string>(list.Count);
    foreach (var item in list)
    {
        var updates = new List<string>();
        foreach (var property in properties)
        {
            // Read from the current item — PropertyInfo of T cannot be read from the List<T> itself.
            var value = property.GetValue(item);
            if (value != null && value.ToString().IsNotNull())
            {
                updates.Add($"{property.Name} = {ToSqlLiteral(value)}");
            }
        }
        if (updates.Count == 0)
        {
            // "update t set  where ..." is invalid SQL and would fail the whole batch.
            continue;
        }
        statements.Add($@"update {table} set {string.Join(", ", updates)} where {key} = {ToSqlLiteral(keyProperty.GetValue(item))};");
    }
    return ExecuteNonQuery(string.Join(Environment.NewLine, statements));

    // First constructor argument of the first attribute whose type name contains the fragment.
    string GetAttributeArgument(string nameFragment)
    {
        var attribute = type.CustomAttributes.FirstOrDefault(a => a.AttributeType.Name.IndexOf(nameFragment, StringComparison.OrdinalIgnoreCase) >= 0);
        var argument = attribute?.ConstructorArguments.FirstOrDefault().Value;
        return argument?.ToString()
            ?? throw new InvalidOperationException($"{type.Name} has no attribute containing '{nameFragment}' with a constructor argument.");
    }

    // Renders a value as a T-SQL literal. Embedded single quotes are doubled so values such as
    // O'Brien neither break nor inject into the statement; only a real null becomes NULL.
    string ToSqlLiteral(object value)
    {
        if (value == null)
        {
            return "NULL";
        }
        string text;
        if (value is DateTime dateTime)
        {
            // ISO 8601 with 'T' is read the same way whatever the session's DATEFORMAT/LANGUAGE.
            text = dateTime.ToString("yyyy-MM-ddTHH:mm:ss.fff", System.Globalization.CultureInfo.InvariantCulture);
        }
        else if (value is IFormattable formattable)
        {
            // Culture-independent numbers ("1.5", never "1,5").
            text = formattable.ToString(null, System.Globalization.CultureInfo.InvariantCulture);
        }
        else
        {
            text = value.ToString();
        }
        return "'" + text.Replace("'", "''") + "'";
    }
}
#region 转Linq 下一篇文章讲解这个
/// <summary>
/// Starts a LINQ-style query over <typeparamref name="T1"/> (the expression-tree query
/// provider is covered in the follow-up article).
/// </summary>
/// <typeparam name="T1">Entity type to query.</typeparam>
/// <returns>A new <c>SelectProvider&lt;T1&gt;</c> bound to this ORM instance.</returns>
public ISelect<T1> Select<T1>() where T1 : new()
{
    return new SelectProvider<T1>(this);
}
#endregion
#endregion
对内方法
#region 私有方法
/// <summary>
/// Executes <paramref name="sql"/> (possibly several statements) inside one transaction and
/// returns the number of affected rows.
/// </summary>
/// <remarks>
/// A SqlException raised while executing or committing rolls the transaction back (when it is
/// still usable), is logged, and yields 0. Failures to open the connection are not caught.
/// </remarks>
/// <param name="sql">T-SQL to run; <c>null</c> or empty runs nothing and returns 0.</param>
/// <returns>Rows affected, or 0 when nothing ran or the batch was rolled back.</returns>
private int ExecuteNonQuery(string sql)
{
    // Nothing to run: don't open a connection and a transaction that would only be rolled back.
    if (string.IsNullOrEmpty(sql))
    {
        return 0;
    }
    using (var connection = new SqlConnection(_connectString))
    {
        connection.Open();
        using (var transaction = connection.BeginTransaction())
        using (var cmd = connection.CreateCommand())
        {
            cmd.Transaction = transaction;
            cmd.CommandText = sql;
            cmd.CommandType = CommandType.Text;
            try
            {
                var effectRows = cmd.ExecuteNonQuery();
                transaction.Commit();
                return effectRows;
            }
            catch (SqlException ex)
            {
                // Connection is null once the transaction is no longer valid (e.g. the server has
                // already rolled it back); calling Rollback then would throw and hide this error.
                if (transaction.Connection != null)
                {
                    transaction.Rollback();
                }
                Logger.Error(ex);
                // Nothing was committed, so report no affected rows.
                return 0;
            }
        }
    }
}
/// <summary>
/// Executes a SELECT and maps each row of the result to a new <typeparamref name="T"/>.
/// </summary>
/// <typeparam name="T">Row type; result columns are copied into its properties by name.</typeparam>
/// <param name="sql">The SELECT statement to execute.</param>
/// <returns>The mapped rows; an empty list when a SqlException occurred (it is logged).</returns>
private List<T> GetQueryList<T>(string sql) where T : new()
{
    using (var connection = new SqlConnection(_connectString))
    {
        var buffer = new DataTable();
        try
        {
            connection.Open();
            using (var adapter = new SqlDataAdapter(sql, connection))
            {
                adapter.Fill(buffer);
                return buffer.ToList<T>();
            }
        }
        catch (System.Data.SqlClient.SqlException ex)
        {
            Logger.Error(ex);
            return new List<T>();
        }
    }
}
/// <summary>
/// Executes a SELECT and maps only the first ROW of the result to a <typeparamref name="T"/>.
/// </summary>
/// <typeparam name="T">Row type; result columns are copied into its properties by name.</typeparam>
/// <param name="sql">The SELECT statement to execute.</param>
/// <returns>
/// The first row; <c>default(T)</c> (null for classes) when the query returns no rows; or a
/// freshly constructed <typeparamref name="T"/> when a SqlException was logged.
/// </returns>
private T QueryFirst<T>(string sql) where T : new()
{
    var fallback = new T();
    using (var connection = new SqlConnection(_connectString))
    {
        var buffer = new DataTable();
        try
        {
            connection.Open();
            using (var adapter = new SqlDataAdapter(sql, connection))
            {
                adapter.Fill(buffer);
                return buffer.ToList<T>().FirstOrDefault();
            }
        }
        catch (System.Data.SqlClient.SqlException ex)
        {
            Logger.Error(ex);
            return fallback;
        }
    }
}
// Row type for Count(): reflection copies the "Count" column of the count query into it.
private class Common
{
    /// <summary>Receives the value of the <c>Count</c> column.</summary>
    public int Count
    {
        get;
        set;
    }
}
/// <summary>
/// Returns the names of the members selected by a property-selector lambda.
/// </summary>
/// <remarks>
/// Supported shapes: a single member (<c>t =&gt; t.Id</c>, including the boxing conversion the
/// compiler inserts for value-type members because the lambda returns <c>object</c>), an
/// anonymous type (<c>t =&gt; new { t.Id, t.Name }</c>) and an object array
/// (<c>t =&gt; new object[] { t.Id, t.Name }</c>).
/// </remarks>
/// <typeparam name="T">Type of the lambda parameter.</typeparam>
/// <param name="exp">The selector expression.</param>
/// <returns>Member names in selection order (duplicates removed for array selectors).</returns>
/// <exception cref="ArgumentNullException"><paramref name="exp"/> is <c>null</c>.</exception>
/// <exception cref="NotSupportedException">The lambda has a shape other than those listed above.</exception>
public static List<string> GetExpressionPropertyNames<T>(this Expression<Func<T, object>> exp)
{
    if (exp == null)
    {
        throw new ArgumentNullException(nameof(exp));
    }
    return SelectMembers(exp.Body).Select(t => t.Name).ToList();
}

/// <summary>
/// Recursively collects the members referenced by a selector body (see GetExpressionPropertyNames).
/// </summary>
private static List<MemberInfo> SelectMembers(Expression properties)
{
    switch (properties)
    {
        case NewExpression newExp when newExp.Members != null:
            // Anonymous type: its members are named after the selected members.
            return newExp.Members.ToList();
        case NewArrayExpression newArr:
            return newArr.Expressions.SelectMany(e => SelectMembers(e)).Distinct().ToList();
        case UnaryExpression unary when unary.NodeType == ExpressionType.Convert || unary.NodeType == ExpressionType.ConvertChecked:
            // Func<T, object> boxes value-type members: t => t.Id is really Convert(t.Id, object).
            return SelectMembers(unary.Operand);
        case MemberExpression member:
            return new List<MemberInfo> { member.Member };
        default:
            throw new NotSupportedException($"Unsupported selector expression '{properties}'. Use t => t.Prop, t => new {{ t.A, t.B }} or t => new object[] {{ t.A, t.B }}.");
    }
}
#endregion
DataTable转List的通用方法,小伙伴们可以收藏哦
/// <summary>
/// Generic DataTable → List conversion: creates one <typeparamref name="T"/> per row and copies
/// each column into the public property of the same name.
/// </summary>
/// <remarks>
/// Column names are matched exactly first, then case-insensitively (so a column "ID" fills a
/// property "Id"). Columns without a matching writable property are ignored, as are cells whose
/// text form fails <c>IsNotNull()</c> (DBNull renders as an empty string). Values of another
/// type are converted: nullable targets use their underlying type, enums are built with
/// <see cref="Enum.ToObject(Type, object)"/> / <see cref="Enum.Parse(Type, string, bool)"/>,
/// everything else goes through <see cref="Convert.ChangeType(object, Type)"/>.
/// NOTE(review): T is expected to be a class — for a struct each SetValue writes to a boxed copy.
/// </remarks>
/// <typeparam name="T">Target row type.</typeparam>
/// <param name="table">Source table.</param>
/// <returns>One object per row, in row order.</returns>
public static List<T> ToList<T>(this DataTable table) where T : new()
{
    var type = typeof(T);
    // Resolve the target property once per column instead of once per row and column.
    var bindings = table.Columns.Cast<DataColumn>()
        .Select(column => new
        {
            Column = column,
            Property = type.GetProperty(column.ColumnName)
                ?? type.GetProperties().FirstOrDefault(p => string.Equals(p.Name, column.ColumnName, StringComparison.OrdinalIgnoreCase)),
        })
        .Where(binding => binding.Property != null && binding.Property.CanWrite)
        .ToList();

    var result = new List<T>(table.Rows.Count);
    foreach (DataRow row in table.Rows)
    {
        T obj = new T();
        foreach (var binding in bindings)
        {
            var value = row[binding.Column];
            if (value.ToString().IsNotNull())
            {
                binding.Property.SetValue(obj, ConvertValue(value, binding.Property.PropertyType), null);
            }
        }
        result.Add(obj);
    }
    return result;

    // Converts a cell value to the property type; values already of that type pass through.
    object ConvertValue(object value, Type targetType)
    {
        if (targetType.IsInstanceOfType(value))
        {
            return value;
        }
        // Convert.ChangeType cannot target Nullable<T>, so convert to the underlying type instead.
        var underlyingType = Nullable.GetUnderlyingType(targetType) ?? targetType;
        if (underlyingType.IsInstanceOfType(value))
        {
            return value;
        }
        if (underlyingType.IsEnum)
        {
            return value is string name
                ? Enum.Parse(underlyingType, name, true)
                : Enum.ToObject(underlyingType, value);
        }
        return Convert.ChangeType(value, underlyingType);
    }
}
使用方法
/// <summary>
/// Usage example for the ORM helper.
/// </summary>
public class TestClass
{
    // Connection string read from configuration; the source of Watch_Com.JsonHelper is in the
    // author's separate JSON article.
    static readonly ORM db = ORM.CreateInstanse(Watch_Com.JsonHelper.AppSetting["ConnectionStrings:SoftDb"]);

    public void testFunc()
    {
        // Insert a row
        var stu = new Student()
        {
            Id = 1,
            Name = "张三",
            Length = 18
        };
        db.Insert(stu);

        // Update it
        stu.Length++;
        db.Update(stu);

        // Query a list. Student is stored in the MapRecord table (see its [TableName] attribute).
        var sql = "select * from MapRecord";
        var stu1 = db.QueryList<Student>(sql);

        // Paged query over the same SQL: page 1, 10 rows per page
        var stu2 = db.QueryListPage<Student>(sql, t => t.Id, 1, 10);
    }
}
总结
基础版的ORM主要使用反射读取实体的属性值并拼接SQL来进行数据库操作,底层仍然是ADO.NET。具体需要实现什么样的功能,还需要根据自己的实际项目来确定。注意:本文为了演示直接拼接SQL,实际项目中应使用参数化查询,以防止SQL注入。
具体的linq解析,表达式树解析版本的ORM框架搭建,可以看我的下一篇文章,应该会在近两天更新。