由于学校大作业不允许使用框架,但是,JDBC代码冗余情况严重,每次操作都要频繁经历“6步”才能得到结果,对上层开发人员来说,过于冗余,所以,自己手动封装了mybatis注解框架(PS:由于解析XML文件过于复杂,其中,<if> <foreach>相互嵌套频繁出bug,暂时成功的只有注解式开发框架)水平有限,如有错误,请大家不吝啬指教。
首先,我们来看下封住完成的框架大致长什么样:
其中,最重要的就是util包了,下面也将重点介绍util包!
mybatis最核心的当然就是大名鼎鼎的动态代理模式,实现 InvocationHandler 接口中的 invoke() 方法,在 invoke 中,首先通过
Annotation[] annotations = method.getAnnotations();
获得所有被代理方法的所有注解,此处明显只有一个,所以
Annotation annotationType = method.getAnnotations();
如此便获得了方法上的注解(如果在方法上多个注解同时使用,大家可以稍作修改,其实也很简单)。
获得注解之后,我们当然需要判断注解类型,@Select @Update @Insert @Delete ,不同类型的注解对应不同类型的操作(我这里比较懒,直接用if else了,现在想想,当时其实可以用策略模式,可能效果更好)
if (annotationType instanceof Select) { // @Select 注解 System.out.println("Select"); Select annotionSelect = method.getAnnotation(Select.class); if (annotionSelect != null) { // @Select 注解 sql = annotionSelect.value()[0]; prepareHandler(method, args, sql); } // 调用新增方法 Object queryResult = crudHandler.query(parseSQL, (Statement) connectionMap.get("statement"), returnType, o, c); System.out.println("queryResult:" + queryResult); return queryResult; } else if (annotationType instanceof Update) { // @Update 注解 System.out.println("Update"); Update annotionUpdate = method.getAnnotation(Update.class); if (annotionUpdate != null) { sql = annotionUpdate.value()[0]; prepareHandler(method, args, sql); } // 调用修改方法 int updateResult = crudHandler.update(parseSQL, (Statement) connectionMap.get("statement")); return updateResult; } else if (annotationType instanceof Insert) { // @Insert 注解 System.out.println("Insert"); Insert annotionInsert = method.getAnnotation(Insert.class); if (annotionInsert != null) { sql = annotionInsert.value()[0]; prepareHandler(method, args, sql); } //调用查询方法 int insertResult = crudHandler.insert(parseSQL, (Statement) connectionMap.get("statement")); return insertResult; } else { // @Delete 注解 System.out.println("Delete"); Delete annotionDelete = method.getAnnotation(Delete.class); if (annotionDelete != null) { sql = annotionDelete.value()[0]; prepareHandler(method, args, sql); }
其中,prepareHandler(method, args, sql)是我自定义的方法,主要是用于直接 JDBC 之前解析 SQL语句
具体执行思路如下:
①检索 sql ,将 sql 中 #{} 换成参数值
②如果有String类型需要特殊处理,加上" ",否则 sql 报错
具体替换过程如下所示:
public static String parseSQL(String sql, Map<String, Object> nameArgsMap) { // 字符串 aegN拼接 int COUNT_INDEX = 0; String argN = "arg" + COUNT_INDEX; StringBuilder stringBuilder = new StringBuilder(); // sql 字符下标值 int index = 0; for (int i = 0; i < sql.length(); i++) { char c = sql.charAt(i); if (c == '#') { // sql 语句中解析到 # index = i + 1; // 解析 sql 下一个字符 char nextChar = sql.charAt(index); if (nextChar != '{') { // 下一个字符不是 {,则抛出 RuntimeException 异常 throw new RuntimeException(String.format("sql语句异常\n,#后面没有{\n sql语句为" + sql)); } // 下一个字符是 { // 将字符与参数匹配 StringBuilder argSB = new StringBuilder(); // 参数变量名 // 解析 sql,变量赋值给 sql 语句 i = parseSQLArgs(argSB, sql, index, argN); // ageN 计数器 + 1 COUNT_INDEX++; argN = "arg" + COUNT_INDEX; // 获得参数变量名,是 map 中的 key String argName = argSB.toString(); // 获得参数数据 Object argValue = nameArgsMap.get(argName); if (argValue == null) { // 参数为空,抛出异常 throw new RuntimeException(String.format("参数异常\n sql语句为:" + sql + "参数为空")); } stringBuilder.append(argValue.toString()); } else { // 不是 # 直接添加 stringBuilder.append(c); } } return stringBuilder.toString(); }
/** * sql 中有 { ,处理 {} 中参数数据 * @param argSB * @param sql * @param index * @param argN * @return */ private static int parseSQLArgs(StringBuilder argSB, String sql, int index, String argN) { index++; // 指向 { 的下一位 for (; index < sql.length(); index++) { char c = sql.charAt(index); // 大括号中有参数 argSB.append(argN); while (c != '}') { index++; c = sql.charAt(index); } return index; } throw new RuntimeException("缺少右括号异常\n sql:" + sql); }
备注比较明确,就不一一解释了,思想就是每个字符检索,遇到#则先判断后续是否合法,若不合法直接抛出异常,合法之后,再在 { 情况下判断是否合法,都合法的话,进行参数替换。
然后就到JDBC操作了!直接上代码,我们最终得到 Connection、statement、clazz对象,并封装到集合中返回,如果只返回statement对象,事物回滚有问题,必须得到本次的 Connection !
@Component public class JDBCConnection { @Value("${spring.datasource.driver-class-name}") private String driver; @Value("${spring.datasource.url}") private String url; @Value("${spring.datasource.username}") private String username; @Value("${spring.datasource.password}") private String password; Map<String, Object> connectionMap; public Map<String, Object> jdbcConnection() { // 读取 application.yml 文件流 Connection connection = null; Map<String, Object> connectionMap = new HashMap<>(); try { Class<?> clazz = Class.forName(driver); try { connection = DriverManager.getConnection(url, username, password); Statement statement = connection.createStatement(); // 不适用自动回滚机制 connection.setAutoCommit(false); // 将 JDBC 数据封装到集合中 connectionMap.put("connection", connection); connectionMap.put("statement", statement); connectionMap.put("clazz", clazz); } catch (SQLException e) { e.printStackTrace(); } } catch (ClassNotFoundException e) { e.printStackTrace(); } this.connectionMap = connectionMap; return connectionMap; } /** * 设置事物回滚 */ public void rollback() { Connection connection = (Connection) connectionMap.get("connection"); try { connection.rollback(); } catch (SQLException e) { e.printStackTrace(); } } /** * 关闭 jdbc 连接 */ public void close() throws SQLException { Connection connection = (Connection) connectionMap.get("connection"); Statement statement = (Statement) connectionMap.get("statement"); Class<?> clazz = (Class<?>) connectionMap.get("clazz"); statement.close(); connection.close(); } }
然后就是增删改查操作:其中,增删改基本一致,都是差不多的操作,为了方便扩展,我为每个操作写了一个具体实现:
@Override public int insert(String sql, Statement statement) throws Exception { System.out.println(sql); int insertResult = 0; try { insertResult = statement.executeUpdate(sql); if (insertResult >= 0) { // 插入数据库成功 // 关闭 JDBC 连接 jdbcConnection.close(); } else { // 插入出现异常 // 事件回滚机制 jdbcConnection.rollback(); // 关闭 JDBC 连接 jdbcConnection.close(); } } catch (Exception e) { // 事件回滚机制 jdbcConnection.rollback(); // 关闭 JDBC 连接 jdbcConnection.close(); throw new Exception("数据库插入异常"); } return insertResult; } @Override public int delete(String sql, Statement statement) throws Exception { System.out.println(sql); int deleteResult = 0; try { deleteResult = statement.executeUpdate(sql); if (deleteResult >= 0) { // 插入数据库成功 // 关闭 JDBC 连接 jdbcConnection.close(); } else { // 插入出现异常 // 事件回滚机制 jdbcConnection.rollback(); // 关闭 JDBC 连接 jdbcConnection.close(); } } catch (Exception e) { // 事件回滚机制 jdbcConnection.rollback(); // 关闭 JDBC 连接 jdbcConnection.close(); throw new Exception("数据库修改异常"); } return deleteResult; } @Override public int update(String sql, Statement statement) throws Exception { System.out.println(sql); int updateResult = 0; try { updateResult = statement.executeUpdate(sql); if (updateResult >= 0) { // 插入数据库成功 // 关闭 JDBC 连接 jdbcConnection.close(); } else { // 插入出现异常 // 事件回滚机制 jdbcConnection.rollback(); // 关闭 JDBC 连接 jdbcConnection.close(); } } catch (Exception e) { // 事件回滚机制 jdbcConnection.rollback(); // 关闭 JDBC 连接 jdbcConnection.close(); throw new Exception("数据库修改异常"); } return updateResult; }
唯一需要注意的就是事物回滚,什么时候需要回滚,什么时候需要抛出异常,自己需要特别注意一下!
最头疼的就是查找,因为要找到返回对象类型(User ?List<User> ? Student ? List<Student > ....)所以,需要xml中所有实体类,然后意义比较获得的是哪种实体类,通过反射,创建该实体类对象。
@Override public Object query(String sql, Statement statement, String returnType, Object o, Class<?> c) throws ParseException, IllegalAccessException, InvocationTargetException { // 创建集合对象 List<Object> lists = new LinkedList<>(); // 获得所有 spring 容器中的实体类对象 List<Object> ClassTypes = daoBeanListener.getInstance(); // 将结果集封装到对象中 Map<String, String> objMap = new HashMap<>(); // 临时存储的集合对象 Map<String, Object> tempMap = new HashMap<>(); int columCount = 0; try { // 查询 System.out.println("查询前的sql:" + sql); ResultSet rs = statement.executeQuery(sql); // 指针置于最后一位 rs.last(); // 获得对象数量 int row = rs.getRow(); System.out.println("row:" + row); // 指针还原 rs.beforeFirst(); // 判断结果类型 for (Object classType : ClassTypes) { if (classType.getClass().isInstance(o)) { // 判断 sql 语句对象类型是否匹配已有实体类 // 将找到实体对象的标志位设置为 true returnTypeValidFlag = true; } } // 判断返回类型 if (returnType.equals("java.util.List")) { // 返回类型为 List log.info("本此查询为 list 类型"); // 判断是否符合类型标准 returnTypeValid = false; queryObjectValid.validCheck(row, returnTypeValid); } else { // 返回为普通对象 System.out.println("本次查询为 对象 类型"); // 判断是否符合类型标准 returnTypeValid = true; queryObjectValid.validCheck(row, true); } if (!returnTypeValidFlag) { // 没有该实体对象 throw new ClassNotFoundException(o + "找不到对应的实体类"); } Method[] methodsPojo = c.getMethods(); Object o2 = new Object(); while (rs.next()) { // 获得对象的所有属性 Field[] declaredFields = c.getDeclaredFields(); for (Field declaredField : declaredFields) { // 属性类型 String attributeType = String.valueOf(declaredField.getType()); // 属性名称 String name = declaredField.getName(); // 属性类型封装 String parameterType = ObjAllClassNameTransFormImpl.getParameterValue(attributeType); // 得到结果集 Object resultItem = rs.getObject(name); // 将结果集放入 map 中 objMap.put(name, String.valueOf(resultItem)); // 获得对象,设置集合属性 pojoInject.setAttribute(c, c.newInstance(), objMap); // 根据实体对象,使用 set() 属性 Object o1 = c.newInstance(); o1 = pojoInject.setAttribute(c, o1, objMap); o2 = o1; } // 将数据封装到集合中 // 判断返回类型 if (returnType.equals("java.util.List")) { // 返回类型为 List lists.add(o2); } else { // 返回为普通对象 return o2; } // 计数器 + 1 columCount++; } // 关闭 JDBC 连接 rs.close(); jdbcConnection.close(); return lists; } catch (SQLException | InstantiationException | ClassNotFoundException e) { e.printStackTrace(); } throw new NullPointerException("数据库查找异常!"); // return null; }
找到实体类方式如下:我这里用了单例模式,通过获得 spring.xml 中的 <bean></bean>对象,得到所有实体类,然后判断哪一个符合本次查询
package com.nchu.software.util.beanFactory.DaoBeanFactory; import org.springframework.context.ApplicationContext; import org.springframework.context.support.ClassPathXmlApplicationContext; import org.springframework.stereotype.Component; import java.util.LinkedList; import java.util.List; /** * Create by @author Ljh 2021/6/4 1:11 * * 单例模式 加载 dao bean 对象 */ @Component public class DaoBeanListenerSingleton { private static List<Object> daoBeanLists; static { daoBeanLists = new LinkedList<>(); } private static final DaoBeanListenerSingleton d = new DaoBeanListenerSingleton(); /** * 获得 dao 对象 bean * @return */ public List<Object> getInstance() { String config="spring.xml"; ApplicationContext ctx = new ClassPathXmlApplicationContext(config); // 对象名 String beans[] = ctx.getBeanDefinitionNames(); for(String bean : beans){ // 通过对象名获得类名 // Class<?> ClassType = ctx.getType(bean); Object ClassType = ctx.getBean(bean); daoBeanLists.add(ClassType); } return daoBeanLists; // return d; } }
最后就是启动的测试,通过 SqlSession 中 getMapper() 获得代理对象,
然后UserMapper mapper = sqlSession.getMapper(UserMapper.class);就OK了!
@Component @Slf4j public class SqlSession { @Autowired private MybatisApplication handler; /** * 获得代理对象 * @param obj * @param <T> * @return */ public <T> T getMapper(Class<T> obj) { // MybatisApplication handler = new MybatisApplication(); return (T) handler.getproxy(obj); } }
package com.nchu.software.util.beanFactory; import com.nchu.software.dao.UserMapper; import com.nchu.software.pojo.User; import com.nchu.software.sqlAnnotation.MybatisApplication; //import com.nchu.software.sqlAnnotation.MybatisRun; import com.nchu.software.sqlAnnotation.SqlSession; import com.nchu.software.util.beanFactory.DaoBeanFactory.DaoBeanListenerSingleton; import lombok.SneakyThrows; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.ApplicationListener; import org.springframework.context.event.ContextRefreshedEvent; import org.springframework.stereotype.Component; import org.springframework.stereotype.Service; import java.util.LinkedList; import java.util.List; /** * Create by @author Ljh 2021/6/4 0:06 * * 在 spring 容器加载之后启动 */ @Component public class ApplicationLoadListener implements ApplicationListener<ContextRefreshedEvent> { @Autowired private MybatisApplication mybatisApplication; @Autowired private SqlSession sqlSession; @SneakyThrows @Override public void onApplicationEvent(ContextRefreshedEvent contextRefreshedEvent) { // mybatis 主启动 UserMapper mapper = sqlSession.getMapper(UserMapper.class); List<User> queryResult = mapper.selectUserListByNoParam(20); System.out.println(queryResult); mapper.selectUserListByNoParam(20); mapper.insertAll(6,"赵高", 20); int insertResult = mapper.insertTwo(6, "黄三", 50, 7, "赵高", 20); System.out.println("insertResult:" + insertResult); int updateResult = mapper.updateOne(70, 6); System.out.println("updateResult:" + updateResult); int deleteResult = mapper.deleteOne( 6); System.out.println("deleteResult:" + deleteResult); } }
好了,框架答题就是这样,有些细节,比如,类型转换,具体返回结果如何解析这些细节,需要调试程序的时候一点点的来,如果有需要请在我个人 gitee 中下载源码https://gitee.com/Nchusw/mybatis-annotation-frame.git,里面有详细的注释