本文和大家一起交流一下我们开发中最常用的ORM框架之一的mybatis,mybatis支持xml配置和注解配置的两种使用模式.本文不是介绍mybatis这个框架,而是实现一个类似于mybatis的框架的基础功能,从而更好的掌握mybatis的原理.用过mybatis的人应该都知道,我们只需要定义一个mapper类型的接口,但却不需要任何的实现?这个很基础但也是整个框架的核心功能,OK 接下来我们就一起分析并实现这样的一个简单框架。
分析我们要实现的难点:
1 既然只需要定义mapper接口,没有对应的实现,所以显然不能直接实例化的,那么我们如何调用呢?常用的解决方案可以从下面几种模式考虑
& 通过接口匿名内部类来实现
& 通过字节码技术来虚拟实现一个接口的实例
& 通过动态代理技术实现接口的调用
而我们会选择基于最后一种也就是基于JDK的动态代理技术实现,因为mybatis底层也是这样实现的.
2 sql里面的参数和注解里面的参数如何绑定?
3 sql里面的#号如何转换成?
4 根据当前sql语句,获取对应的操作类型?
5 截取sql语句中的where条件?、
OK 看一下项目整体结构:
annoation:定义常用的注解
功能分别如下:
package com.qb.orm.annoation; import java.lang.annotation.*; /** * @Author 18011618 * @Description 自定义删除注解 * @Date 9:26 2018/6/18 * @Modify By */ @Documented @Retention(RetentionPolicy.RUNTIME) @Target(ElementType.METHOD) public @interface OrmDelete { String value(); }
package com.qb.orm.annoation; import java.lang.annotation.*; /** * @Author 18011618 * @Description 自定义插入注解 * @Date 9:25 2018/6/18 * @Modify By */ @Documented @Target(ElementType.METHOD) @Retention(RetentionPolicy.RUNTIME) public @interface OrmInsert { String value(); }
package com.qb.orm.annoation; import java.lang.annotation.*; /** * @Author 18011618 * @Description 自定义查询注解 * @Date 9:27 2018/6/18 * @Modify By */ @Documented @Target(ElementType.METHOD) @Retention(RetentionPolicy.RUNTIME) public @interface OrmSelect { String value(); }
package com.qb.orm.annoation; import java.lang.annotation.*; /** * @Author 18011618 * @Description 自定义更新注解 * @Date 9:28 2018/6/18 * @Modify By */ @Documented @Target(ElementType.METHOD) @Retention(RetentionPolicy.RUNTIME) public @interface OrmUpdate { String value(); }
package com.qb.orm.annoation; import java.lang.annotation.*; /** * @Author 18011618 * @Description 自定义参数注解 * @Date 9:29 2018/6/18 * @Modify By */ @Documented @Target(ElementType.PARAMETER) @Retention(RetentionPolicy.RUNTIME) public @interface OrmParam { String value(); }
entity:定义一个user实体
package com.qb.orm.entity; /** * 用户实体类 */ public class User { private Integer id; private String userName; private Integer userAge; public String getUserName() { return userName; } public void setUserName(String userName) { this.userName = userName; } public Integer getUserAge() { return userAge; } public void setUserAge(Integer userAge) { this.userAge = userAge; } public Integer getId() { return id; } public void setId(Integer id) { this.id = id; } }
mapper:定义user对应的mapper接口
package com.qb.orm.mapper; import com.qb.orm.annoation.*; import com.qb.orm.entity.User; import java.util.List; /** * @Author 18011618 * @Description 用户数据访问层接口 * @Date 9:39 2018/6/18 * @Modify By */ public interface UserMapper { /** * @Author 18011618 * @Date 9:46 2018/6/18 * @Function 插入接口 */ @OrmInsert("insert into user(userName,userAge) values(#{userName},#{userAge})") int insertUser(@OrmParam(value = "userName") String userName, @OrmParam(value = "userAge") Integer userAge); /** * @Author 18011618 * @Date 9:46 2018/6/18 * @Function 删除接口 */ @OrmDelete("delete from user where id =#{id}") int deleteUser(@OrmParam(value = "id")Integer id); /** * @Author 18011618 * @Date 9:47 2018/6/18 * @Function 更新接口 */ @OrmUpdate("update user set userName =#{userName} where id = #{id}") int updateUser(@OrmParam(value = "userName")String userName,@OrmParam(value = "id") Integer id); /** * @Author 18011618 * @Date 9:47 2018/6/18 * @Function 查询接口 */ @OrmSelect("select * from User where userName=#{userName} and userAge=#{userAge}") List<User>selectUser(@OrmParam("userName") String name, @OrmParam("userAge") Integer userAge); }
util:定义工具类,主要有连接JDBC和对sql的操作,具体看代码
package com.qb.orm.util; import java.sql.Connection; import java.sql.DriverManager; import java.sql.PreparedStatement; import java.sql.ResultSet; import java.sql.ResultSetMetaData; import java.sql.SQLException; import java.sql.Statement; import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; /** * @Author 18011618 * @Date 10:41 2018/6/18 * @Function jdbc工具类 */ public final class JDBCUtil { private static String connect; private static String driverClassName; private static String URL; private static String username; private static String password; private static boolean autoCommit; /** 声明一个 Connection类型的静态属性,用来缓存一个已经存在的连接对象 */ private static Connection conn; static { config(); } /** * 开头配置自己的数据库信息 */ private static void config() { /* * 获取驱动 */ driverClassName = "com.mysql.jdbc.Driver"; /* * 获取URL */ URL = "jdbc:mysql://localhost:3306/test?useUnicode=true&characterEncoding=utf8"; /* * 获取用户名 */ username = "root"; /* * 获取密码 */ password = "root"; /* * 设置是否自动提交,一般为false不用改 */ autoCommit = false; } /** * 载入数据库驱动类 */ private static boolean load() { try { Class.forName(driverClassName); return true; } catch (ClassNotFoundException e) { System.out.println("驱动类 " + driverClassName + " 加载失败"); } return false; } /** * 专门检查缓存的连接是否不可以被使用 ,不可以被使用的话,就返回 true */ private static boolean invalid() { if (conn != null) { try { if (conn.isClosed() || !conn.isValid(3)) { return true; /* * isValid方法是判断Connection是否有效,如果连接尚未关闭并且仍然有效,则返回 true */ } } catch (SQLException e) { e.printStackTrace(); } /* * conn 既不是 null 且也没有关闭 ,且 isValid 返回 true,说明是可以使用的 ( 返回 false ) */ return false; } else { return true; } } /** * 建立数据库连接 */ public static Connection connect() { if (invalid()) { /* invalid为true时,说明连接是失败的 */ /* 加载驱动 */ load(); try { /* 建立连接 */ conn = DriverManager.getConnection(URL, username, password); } catch (SQLException e) { System.out.println("建立 " + connect + " 数据库连接失败 , " + e.getMessage()); } } return conn; } /** * 设置是否自动提交事务 **/ public static void transaction() { try { conn.setAutoCommit(autoCommit); } catch (SQLException e) { System.out.println("设置事务的提交方式为 : " + (autoCommit ? "自动提交" : "手动提交") + " 时失败: " + e.getMessage()); } } /** * 创建 Statement 对象 */ public static Statement statement() { Statement st = null; connect(); /* 如果连接是无效的就重新连接 */ transaction(); /* 设置事务的提交方式 */ try { st = conn.createStatement(); } catch (SQLException e) { System.out.println("创建 Statement 对象失败: " + e.getMessage()); } return st; } /** * 根据给定的带参数占位符的SQL语句,创建 PreparedStatement 对象 * * @param SQL * 带参数占位符的SQL语句 * @return 返回相应的 PreparedStatement 对象 */ private static PreparedStatement prepare(String SQL, boolean autoGeneratedKeys) { PreparedStatement ps = null; connect(); /* 如果连接是无效的就重新连接 */ transaction(); /* 设置事务的提交方式 */ try { if (autoGeneratedKeys) { ps = conn.prepareStatement(SQL, Statement.RETURN_GENERATED_KEYS); } else { ps = conn.prepareStatement(SQL); } } catch (SQLException e) { System.out.println("创建 PreparedStatement 对象失败: " + e.getMessage()); } return ps; } /** * 获取查询结果集 * @param SQL * @param params * @return */ public static ResultSet query(String SQL, List<Object> params) { if (SQL == null || SQL.trim().isEmpty() || !SQL.trim().toLowerCase().startsWith("select")) { throw new RuntimeException("你的SQL语句为空或不是查询语句"); } ResultSet rs = null; if (params.size() > 0) { /* 说明 有参数 传入,就需要处理参数 */ PreparedStatement ps = prepare(SQL, false); try { for (int i = 0; i < params.size(); i++) { ps.setObject(i + 1, params.get(i)); } rs = ps.executeQuery(); } catch (SQLException e) { System.out.println("执行SQL失败: " + e.getMessage()); } } else { /* 说明没有传入任何参数 */ Statement st = statement(); try { rs = st.executeQuery(SQL); // 直接执行不带参数的 SQL 语句 } catch (SQLException e) { System.out.println("执行SQL失败: " + e.getMessage()); } } return rs; } /** * 判断数据类型 * @param o * @return */ private static Object typeof(Object o) { Object r = o; if (o instanceof java.sql.Timestamp) { return r; } // 将 java.util.Date 转成 java.sql.Date if (o instanceof java.util.Date) { java.util.Date d = (java.util.Date) o; r = new java.sql.Date(d.getTime()); return r; } // 将 Character 或 char 变成 String if (o instanceof Character || o.getClass() == char.class) { r = String.valueOf(o); return r; } return r; } /** * 根据传过来的语句 判断具体的操作 * @param SQL * @param params * @return */ public static boolean execute(String SQL, Object... params) { if (SQL == null || SQL.trim().isEmpty() || SQL.trim().toLowerCase().startsWith("select")) { throw new RuntimeException("你的SQL语句为空或有错"); } boolean r = false; /* 表示 执行 DDL 或 DML 操作是否成功的一个标识变量 */ /* 获得 被执行的 SQL 语句的 前缀 */ SQL = SQL.trim(); SQL = SQL.toLowerCase(); String prefix = SQL.substring(0, SQL.indexOf(" ")); String operation = ""; // 用来保存操作类型的 变量 // 根据前缀 确定操作 switch (prefix) { case "create": operation = "create table"; break; case "alter": operation = "update table"; break; case "drop": operation = "drop table"; break; case "truncate": operation = "truncate table"; break; case "insert": operation = "insert :"; break; case "update": operation = "update :"; break; case "delete": operation = "delete :"; break; } if (params.length > 0) { // 说明有参数 PreparedStatement ps = prepare(SQL, false); Connection c = null; try { c = ps.getConnection(); } catch (SQLException e) { e.printStackTrace(); } try { for (int i = 0; i < params.length; i++) { Object p = params[i]; p = typeof(p); ps.setObject(i + 1, p); } ps.executeUpdate(); commit(c); r = true; } catch (SQLException e) { System.out.println(operation + " 失败: " + e.getMessage()); rollback(c); } } else { // 说明没有参数 Statement st = statement(); Connection c = null; try { c = st.getConnection(); } catch (SQLException e) { e.printStackTrace(); } // 执行 DDL 或 DML 语句,并返回执行结果 try { st.executeUpdate(SQL); commit(c); // 提交事务 r = true; } catch (SQLException e) { System.out.println(operation + " 失败: " + e.getMessage()); rollback(c); // 回滚事务 } } return r; } /* * * @param SQL 需要执行的 INSERT 语句 * * @param autoGeneratedKeys 指示是否需要返回由数据库产生的键 * * @param params 将要执行的SQL语句中包含的参数占位符的 参数值 * * @return 如果指定 autoGeneratedKeys 为 true 则返回由数据库产生的键; 如果指定 autoGeneratedKeys * 为 false 则返回受当前SQL影响的记录数目 */ public static int insert(String SQL, boolean autoGeneratedKeys, List<Object> params) { int var = -1; if (SQL == null || SQL.trim().isEmpty()) { throw new RuntimeException("你没有指定SQL语句,请检查是否指定了需要执行的SQL语句"); } // 如果不是 insert 开头开头的语句 if (!SQL.trim().toLowerCase().startsWith("insert")) { System.out.println(SQL.toLowerCase()); throw new RuntimeException("你指定的SQL语句不是插入语句,请检查你的SQL语句"); } // 获得 被执行的 SQL 语句的 前缀 ( 第一个单词 ) SQL = SQL.trim(); SQL = SQL.toLowerCase(); if (params.size() > 0) { // 说明有参数 PreparedStatement ps = prepare(SQL, autoGeneratedKeys); Connection c = null; try { c = ps.getConnection(); // 从 PreparedStatement 对象中获得 它对应的连接对象 } catch (SQLException e) { e.printStackTrace(); } try { for (int i = 0; i < params.size(); i++) { Object p = params.get(i); p = typeof(p); ps.setObject(i + 1, p); } int count = ps.executeUpdate(); if (autoGeneratedKeys) { // 如果希望获得数据库产生的键 ResultSet rs = ps.getGeneratedKeys(); // 获得数据库产生的键集 if (rs.next()) { // 因为是保存的是单条记录,因此至多返回一个键 var = rs.getInt(1); // 获得值并赋值给 var 变量 } } else { var = count; // 如果不需要获得,则将受SQL影像的记录数赋值给 var 变量 } commit(c); } catch (SQLException e) { System.out.println("数据保存失败: " + e.getMessage()); rollback(c); } } else { // 说明没有参数 Statement st = statement(); Connection c = null; try { c = st.getConnection(); // 从 Statement 对象中获得 它对应的连接对象 } catch (SQLException e) { e.printStackTrace(); } // 执行 DDL 或 DML 语句,并返回执行结果 try { int count = st.executeUpdate(SQL); if (autoGeneratedKeys) { // 如果企望获得数据库产生的键 ResultSet rs = st.getGeneratedKeys(); // 获得数据库产生的键集 if (rs.next()) { // 因为是保存的是单条记录,因此至多返回一个键 var = rs.getInt(1); // 获得值并赋值给 var 变量 } } else { var = count; // 如果不需要获得,则将受SQL影像的记录数赋值给 var 变量 } commit(c); // 提交事务 } catch (SQLException e) { System.out.println("数据保存失败: " + e.getMessage()); rollback(c); // 回滚事务 } } return var; } /** 提交事务 */ private static void commit(Connection c) { if (c != null && !autoCommit) { try { c.commit(); } catch (SQLException e) { e.printStackTrace(); } } } /** 回滚事务 */ private static void rollback(Connection c) { if (c != null && !autoCommit) { try { c.rollback(); } catch (SQLException e) { e.printStackTrace(); } } } /** * 释放资源 **/ public static void release(Object cloaseable) { if (cloaseable != null) { if (cloaseable instanceof ResultSet) { ResultSet rs = (ResultSet) cloaseable; try { rs.close(); } catch (SQLException e) { e.printStackTrace(); } } if (cloaseable instanceof Statement) { Statement st = (Statement) cloaseable; try { st.close(); } catch (SQLException e) { e.printStackTrace(); } } if (cloaseable instanceof Connection) { Connection c = (Connection) cloaseable; try { c.close(); } catch (SQLException e) { e.printStackTrace(); } } } } }
SQLUtil:主要是处理sql语句的辅助功能,具体看代码注释
package com.qb.orm.util; import java.util.ArrayList; import java.util.List; /** * SQL处理的工具类 * @author 18011618 * */ public class SQLUtil{ /** * @Author 18011618 * @Date 10:50 2018/6/18 * @Function 获取sql语句中value后面的参数 */ public static String[] sqlInsertParameter(String sql) { int startIndex = sql.indexOf("values"); int endIndex = sql.length(); String substring = sql.substring(startIndex + 6, endIndex).replace("(", "").replace(")", "").replace("#{", "") .replace("}", ""); String[] split = substring.split(","); return split; } /** * * 获取select 后面where语句 * @param sql * @return */ public static List<String> sqlSelectParameter(String sql) { int startIndex = sql.indexOf("where"); int endIndex = sql.length(); String substring = sql.substring(startIndex + 5, endIndex); String[] split = substring.split("and"); List<String> listArr = new ArrayList<>(); for (String string : split) { String[] sp2 = string.split("="); listArr.add(sp2[0].trim()); } return listArr; } /** * 将SQL语句的参数替换变为? * @param sql * @param parameterName * @return */ public static String parameQuestion(String sql, String[] parameterName) { for (int i = 0; i < parameterName.length; i++) { String string = parameterName[i]; sql = sql.replace("#{" + string + "}", "?"); } return sql; } public static String parameQuestion(String sql, List<String> parameterName) { for (int i = 0; i < parameterName.size(); i++) { String string = parameterName.get(i); sql = sql.replace("#{" + string + "}", "?"); } return sql; } }
session:定义实现与sql进行交互的sqlsession,主要是通过JDK的动态代理实现的
package com.qb.orm.session; import com.qb.orm.proxy.OrmInvocationHandler; import java.lang.reflect.Proxy; /** * @Author 18011618 * @Description 创建sqlsession会话 * @Date 9:48 2018/6/18 * @Modify By */ public class SqlSession { /** * @Author 18011618 * @Date 10:00 2018/6/18 * @Function 通过代理获取具体的目标对象 */ public static <T> T getMapper(Class clazz){ return (T)Proxy.newProxyInstance(clazz.getClassLoader(),new Class []{clazz},new OrmInvocationHandler(clazz)); } }
proxy:具体实现mapper的动态代理调用功能,最终还是调用JDBCUtil里面对应的方法
package com.qb.orm.proxy; import com.qb.orm.annoation.OrmDelete; import com.qb.orm.annoation.OrmInsert; import com.qb.orm.annoation.OrmSelect; import com.qb.orm.annoation.OrmUpdate; import com.qb.orm.proxy.handler.AnnoationHandler; import java.lang.annotation.Annotation; import java.lang.reflect.InvocationHandler; import java.lang.reflect.Method; /** * @Author 18011618 * @Description 执行mapper接口的handler * @Date 9:51 2018/6/18 * @Modify By */ public class OrmInvocationHandler implements InvocationHandler { private Object targetObject; public OrmInvocationHandler(Object targetObject){ this.targetObject = targetObject; } /** * 执行具体的拦截逻辑处理 * @param proxy * @param method * @param args * @return * @throws Throwable */ public Object invoke(Object proxy, Method method, Object[] args) throws Throwable { //判断是否有插入注解 boolean insertFalg = method.isAnnotationPresent(OrmInsert.class); if (insertFalg){ OrmInsert ormInsert = method.getDeclaredAnnotation(OrmInsert.class); if (ormInsert!=null){ return AnnoationHandler.handleOrmInsert(ormInsert,method,args); } } //判断是否有查询注解 boolean selectFalg = method.isAnnotationPresent(OrmSelect.class); if (selectFalg){ OrmSelect ormSelect = method.getDeclaredAnnotation(OrmSelect.class); if (ormSelect!=null){ return AnnoationHandler.handleOrmSelect(ormSelect,proxy,method,args); } } return null; } }
上面只是实现了增加和查询,删除或者修改完全是类似的,可以自己直接扩展,具体的实现功能是在AnnoationHandler这个类里面,下面具体看这个类的相关代码:
package com.qb.orm.proxy.handler; import com.qb.orm.annoation.OrmInsert; import com.qb.orm.annoation.OrmParam; import com.qb.orm.annoation.OrmSelect; import com.qb.orm.util.JDBCUtil; import com.qb.orm.util.SQLUtil; import java.lang.reflect.Field; import java.lang.reflect.Method; import java.lang.reflect.Parameter; import java.sql.ResultSet; import java.util.ArrayList; import java.util.List; import java.util.concurrent.ConcurrentHashMap; /** * @Author 18011618 * @Description 处理具体的各种注解的 * @Date 11:01 2018/6/18 * @Modify By */ public class AnnoationHandler { /** * 处理插入操作 * @param extInsert * @param method * @param args * @return */ public static Object handleOrmInsert(OrmInsert extInsert, Method method, Object[] args) { //获取注解上面的sql语句 String insertSql = extInsert.value(); //方法里面的参数和注解里面的参数进行绑定 ConcurrentHashMap<Object, Object> paramsMap = paramsMap( method, args); // 存放sql执行的参数---参数绑定过程 String[] sqlInsertParameter = SQLUtil.sqlInsertParameter(insertSql); List<Object> sqlParams = sqlParams(sqlInsertParameter, paramsMap); // 4. 根据参数替换参数变为? String newSQL = SQLUtil.parameQuestion(insertSql, sqlInsertParameter); System.out.println("newSQL:" + newSQL + ",sqlParams:" + sqlParams.toString()); // 5. 调用jdbc底层代码执行语句 return JDBCUtil.insert(newSQL, false, sqlParams); } /** * 处理查询操作 * @param ormSelect * @param proxy * @param method * @param args * @return */ public static Object handleOrmSelect(OrmSelect ormSelect,Object proxy,Method method,Object[] args){ String selectSql = ormSelect.value(); ConcurrentHashMap<Object, Object> paramsMap = paramsMap(method, args); List<String> sqlSelectParameter = SQLUtil.sqlSelectParameter(selectSql); List<Object> sqlParams = new ArrayList<>(); for (String parameterName : sqlSelectParameter) { Object parameterValue = paramsMap.get(parameterName); sqlParams.add(parameterValue); } String newSql = SQLUtil.parameQuestion(selectSql, sqlSelectParameter); System.out.println("newSQL:" + newSql + ",sqlParams:" + sqlParams.toString()); ResultSet res = JDBCUtil.query(newSql, sqlParams); try { if (!res.next()){ return null; } // 下标往上移动移位 res.previous(); //获取方法的返回类型 Class<?> returnType = method.getReturnType(); Object obj = returnType.newInstance(); while (res.next()) { // 获取当前所有的属性 Field[] declaredFields = returnType.getDeclaredFields(); for (Field field : declaredFields) { String fieldName = field.getName(); Object fieldValue = res.getObject(fieldName); field.setAccessible(true); field.set(obj, fieldValue); } } return obj; } catch (Exception e){ } return null; } /** * 转换成对应的参数 * @param sqlInsertParameter * @param paramsMap * @return */ private static List<Object> sqlParams(String[] sqlInsertParameter, ConcurrentHashMap<Object, Object> paramsMap) { List<Object> sqlParams = new ArrayList<>(); for (String paramName : sqlInsertParameter) { Object paramValue = paramsMap.get(paramName); sqlParams.add(paramValue); } return sqlParams; } /** * 方法里面的参数和注解里面的参数进行绑定 * @param method * @param args * @return */ private static ConcurrentHashMap<Object, Object> paramsMap(Method method, Object[] args) { ConcurrentHashMap<Object, Object> paramsMap = new ConcurrentHashMap<>(); // 获取方法上的参数 Parameter[] parameters = method.getParameters(); for (int i = 0; i < parameters.length; i++) { Parameter parameter = parameters[i]; OrmParam ormParam = parameter.getDeclaredAnnotation(OrmParam.class); if (ormParam != null) { // 参数名称 String paramName = ormParam.value(); Object paramValue = args[i]; paramsMap.put(paramName, paramValue); } } return paramsMap; } }
ok,代码的注释都很清楚,事实我们一个简单的orm框架就完成了,本质并没有多难,只要能够把问题一步步分析清楚,然后去解决它即可,下面看一个简单的测试类:
package com.qb.orm.test; import com.qb.orm.entity.User; import com.qb.orm.mapper.UserMapper; import com.qb.orm.session.SqlSession; import java.util.List; /** * @Author 18011618 * @Description 测试ORM功能 * @Date 10:03 2018/6/18 * @Modify By */ public class OrmTest { public static void main(String[] args) { UserMapper userMapper = SqlSession.getMapper(UserMapper.class); List<User> users = userMapper.selectUser("",0); for (User user:users) { System.out.println(user.getUserName()+":" + user.getUserAge() +":" + user.getId()); } } }