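The source code links are at the end of the article.
I wrote this post after the code was finished, so every snippet shown is the final version.
Introduction:
This is a personal learning project; the code is not great and a bit messy.
By studying the MyBatis source code, I built a simplified imitation of MyBatis (still a work in progress).
Implemented so far:
CRUD (with XML configuration; I also wrote a little of the annotation configuration, but the principle is the same, so I won't cover it),
transactions,
a database connection pool,
dynamic SQL (partially; the principle works, and I'll flesh it out when I have time),
first- and second-level caches
1. Getting started
Let's start with this piece of code: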
// 1. Read the configuration file
InputStream in = Resources.getResourceAsStream("SqlMapConfig.xml");
// 2. Create the SqlSessionFactory
SqlSessionFactoryBuilder builder = new SqlSessionFactoryBuilder();
SqlSessionFactory factory = builder.build(in);
// 3. Use the factory to produce a SqlSession (true = autoCommit)
SqlSession session = factory.openSession(true);
//SqlSession session = SessionUtil.getSession();
// 4. Use the SqlSession to create a proxy object for the DAO interface
IUserDao userDao = session.getMapper(IUserDao.class);
// 5. Call methods on the proxy (the commented-out lines are other test cases)
// List<User> user = userDao.findAll();
//User u = new User();
//u.setSex("男");
//User user = userDao.findById("41");
// List<User> user = userDao.findSome("%王%","男");
List<User> user = userDao.findSomeOgnl(new User(105,"%王%",new Date(1920,11,23),null,"江西"));
// userDao.insertUser(new User(105,"小q",new Date(1920,11,23),"男","江西"));
// userDao.deletebyId("49");
// long time = System.currentTimeMillis();
// java.sql.Date date = new java.sql.Date(time);
// java.sql.Timestamp timestamp = new Timestamp(time);
//userDao.update(new User(42,"小花k",timestamp,"女","江西"));
// System.out.println(user);
for(User u : user){
System.out.println(u);
}
// 6. Release resources
//session.commit();
session.close();
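This code shows how MyBatis is used to execute SQL, with some of my test cases mixed in.
The comments spell out each step: a factory produces the SqlSession, and the SqlSession then creates a proxy object for the DAO interface. Our job is to implement these classes ourselves.
2. Reading the configuration files
There are several ways to read and parse XML, such as DOM, DOM4J and SAX. I used dom4j.
- Reading the main configuration file
First, let's look at the MyBatis XML configuration: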
<?xml version="1.0" encoding="UTF-8"?>
<!-- MyBatis main configuration file -->
<configuration>
<!-- Environments -->
<environments default="mysql">
<!-- MySQL environment -->
<environment id="mysql">
<!-- Transaction type -->
<transactionManager type="JDBC"></transactionManager>
<!-- Data source (connection pool) -->
<dataSource type="POOLED">
<!-- The 4 basic database connection settings; '&' must be written as &amp; inside an XML attribute -->
<property name="driver" value="com.mysql.cj.jdbc.Driver"/>
<property name="url" value="jdbc:mysql://localhost:3306/eesy_mybatis?serverTimezone=GMT%2B8&amp;useSSL=false"/>
<property name="username" value="root"/>
<property name="password" value=""/>
</dataSource>
</environment>
</environments>
<!-- Location of the mapper files; each DAO has its own mapper file -->
<mappers>
<mapper resource="com/baiye/dao/IUserDao.xml"/>
<!-- <mapper class="com.baiye.www.dao.IUserDao"/>-->
</mappers>
</configuration>
We need to parse it. Here is the code:
public static Configuration loadConfiguration(InputStream in) {
Configuration config = new Configuration();
SAXReader saxReader = new SAXReader();
try {
Document document = saxReader.read(in);
Element root = document.getRootElement();
//XPath: '//' selects matching nodes anywhere in the document, not only under the current node
List dataSourceNode = root.selectNodes("//dataSource");
String type = ((Element) dataSourceNode.get(0)).attributeValue("type");
config.setDataSourceType(type);
List nodes = root.selectNodes("//property");
Iterator iterator = nodes.iterator();
while (iterator.hasNext()) {
Element element = (Element) iterator.next();
String name = element.attributeValue("name");
String value = element.attributeValue("value");
if ("driver".equals(name)) {
config.setDriver(value);
} else if ("url".equals(name)) {
config.setUrl(value);
} else if ("username".equals(name)) {
config.setUsername(value);
} else if ("password".equals(name)) {
config.setPassword(value);
} else {
throw new RuntimeException("xml sql config error!");
}
}
//mappers
List mapperList = root.selectNodes("//mappers/mapper");
Iterator mapperIterator = mapperList.iterator();
Map<String, Mapper> mappers = new HashMap<>();
while (mapperIterator.hasNext()) {
Element element = (Element) mapperIterator.next();
Attribute attribute = element.attribute("resource");
if (attribute != null) {
// XML mapper file
String classPath = attribute.getValue();
Map<String, Mapper> mapper = loadMapperXMLConfiguration(classPath);
mappers.putAll(mapper);
} else if (element.attribute("class") != null) {
// annotated mapper interface
String classPath = element.attribute("class").getValue();
Map<String, Mapper> mapper = loadMapperAnnotation(classPath);
mappers.putAll(mapper);
} else {
throw new RuntimeException("xml mapper config error");
}
}
config.setMappers(mappers);
} catch (Exception e) {
// Fail fast: continuing with a half-filled Configuration would only cause a NullPointerException below
throw new RuntimeException("loadConfiguration error!", e);
}
if ("POOLED".equals(config.getDataSourceType())) {
config.setEnvironment(new Environment(null, null, new PooledDataSourceFactory(config).getDataSource()));
} else if ("UNPOOLED".equals(config.getDataSourceType())) {
config.setEnvironment(new Environment(null, null, new UnpooledDataSourceFactory(config).getDataSource()));
}
return config;
}
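For comparison, here is a minimal sketch of the same dataSource lookup using only the JDK's built-in DOM and XPath APIs (javax.xml) instead of dom4j. The class name DataSourceConfigReader is hypothetical. Unlike the "//property" query above, it reads only the <property> elements nested inside <dataSource>, so property elements added elsewhere in the file later would not hit the "xml sql config error!" branch.

```java
import java.io.ByteArrayInputStream;
import java.nio.charset.StandardCharsets;
import java.util.LinkedHashMap;
import java.util.Map;
import javax.xml.parsers.DocumentBuilderFactory;
import javax.xml.xpath.XPath;
import javax.xml.xpath.XPathConstants;
import javax.xml.xpath.XPathFactory;
import org.w3c.dom.Document;
import org.w3c.dom.Element;
import org.w3c.dom.NodeList;

/** Hypothetical helper: reads the dataSource settings with the JDK's DOM + XPath instead of dom4j. */
final class DataSourceConfigReader {
    final String type;                                            // e.g. POOLED or UNPOOLED
    final Map<String, String> properties = new LinkedHashMap<>(); // driver, url, username, password

    private DataSourceConfigReader(String type) {
        this.type = type;
    }

    static DataSourceConfigReader read(String xml) {
        try {
            Document doc = DocumentBuilderFactory.newInstance().newDocumentBuilder()
                    .parse(new ByteArrayInputStream(xml.getBytes(StandardCharsets.UTF_8)));
            XPath xpath = XPathFactory.newInstance().newXPath();
            // "//dataSource" finds the element anywhere in the document
            Element dataSource = (Element) xpath.evaluate("//dataSource", doc, XPathConstants.NODE);
            if (dataSource == null) {
                throw new IllegalArgumentException("no <dataSource> element");
            }
            DataSourceConfigReader config = new DataSourceConfigReader(dataSource.getAttribute("type"));
            // Relative path: only the <property> children of this dataSource
            NodeList props = (NodeList) xpath.evaluate("property", dataSource, XPathConstants.NODESET);
            for (int i = 0; i < props.getLength(); i++) {
                Element p = (Element) props.item(i);
                config.properties.put(p.getAttribute("name"), p.getAttribute("value"));
            }
            return config;
        } catch (IllegalArgumentException e) {
            throw e;
        } catch (Exception e) {
            // Parser and XPath errors are checked exceptions; surface them unchecked
            throw new IllegalStateException("cannot read configuration", e);
        }
    }
}
```

The parser hands back attribute values already decoded, so a url written with &amp; in the file comes back with a plain &.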
- Reading the mapper configuration files
A mapper file looks like this:
<?xml version="1.0" encoding="UTF-8"?>
<mapper namespace="com.baiye.www.dao.IUserDao">
<!-- Find a user by id -->
<select id="findById" resultType="com.baiye.www.domain.User" parameterType="Integer">
select * from user where id= #{id}
</select>
<select id="findAll" resultType="com.baiye.www.domain.User">
select * from user
</select>
<select id="findSome" resultType="com.baiye.www.domain.User">
select * from `user` where username LIKE #{username} and sex = #{sex}
</select>
<select id="findSomeOgnl" resultType="com.baiye.www.domain.User" parameterType="com.baiye.www.domain.User">
select * from `user` where username LIKE #{username} and username LIKE #{username}
<if test="sex != null">
and sex = #{sex}
</if>
</select>
<insert id="insertUser" parameterType="com.baiye.www.domain.User">
insert into user(username,birthday,sex,address) values(#{username},#{birthday},#{sex},#{address})
</insert>
<delete id="deletebyId" parameterType="Integer">
delete from user where id = #{id}
</delete>
<update id="update" parameterType="com.baiye.www.domain.User">
update user set username=#{username},birthday=#{birthday},sex=#{sex} where id=#{id}
</update>
</mapper>
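The findSomeOgnl statement above uses an <if> element. As a rough illustration of what evaluating it involves, here is a minimal sketch that renders a statement element by keeping its plain text plus the body of every <if> whose test passes. Assumptions: the class name SimpleDynamicSql is hypothetical, only tests of the form `field != null` or `field == null` are supported (no OGNL), and the field is read directly by reflection rather than through a getter.

```java
import java.io.ByteArrayInputStream;
import java.lang.reflect.Field;
import java.nio.charset.StandardCharsets;
import javax.xml.parsers.DocumentBuilderFactory;
import org.w3c.dom.Element;
import org.w3c.dom.Node;
import org.w3c.dom.NodeList;

/** Hypothetical sketch of rendering a statement that contains <if test="..."> elements. */
final class SimpleDynamicSql {
    static String render(String statementXml, Object param) {
        Element statement;
        try {
            statement = DocumentBuilderFactory.newInstance().newDocumentBuilder()
                    .parse(new ByteArrayInputStream(statementXml.getBytes(StandardCharsets.UTF_8)))
                    .getDocumentElement();
        } catch (Exception e) {
            throw new IllegalArgumentException("cannot parse statement", e);
        }
        StringBuilder sql = new StringBuilder();
        NodeList children = statement.getChildNodes();
        for (int i = 0; i < children.getLength(); i++) {
            Node child = children.item(i);
            if (child.getNodeType() == Node.TEXT_NODE) {
                sql.append(child.getTextContent()); // plain SQL text
            } else if (child instanceof Element && "if".equals(((Element) child).getTagName())
                    && test(((Element) child).getAttribute("test"), param)) {
                sql.append(child.getTextContent()); // body of a passing <if>
            }
        }
        // Collapse the whitespace left over from the XML indentation
        return sql.toString().trim().replaceAll("\\s+", " ");
    }

    /** Supports only "field != null" and "field == null". */
    private static boolean test(String expression, Object param) {
        String[] parts = expression.trim().split("\\s+");
        if (parts.length != 3 || !"null".equals(parts[2])) {
            throw new IllegalArgumentException("unsupported test: " + expression);
        }
        Object value;
        try {
            Field field = param.getClass().getDeclaredField(parts[0]);
            field.setAccessible(true);
            value = field.get(param);
        } catch (ReflectiveOperationException e) {
            throw new IllegalArgumentException("cannot read " + parts[0], e);
        }
        if ("!=".equals(parts[1])) return value != null;
        if ("==".equals(parts[1])) return value == null;
        throw new IllegalArgumentException("unsupported operator: " + parts[1]);
    }
}
```

The #{...} placeholders are left untouched here; filling them in is a separate step.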
The code below reads the mapper configuration, with one method for annotations and one for XML:
/**
 * Annotation config: builds a Mapper (sql, resultType) for every @Select method
 * of the interface with the given fully-qualified name
 *
 * @param classPath fully-qualified name of the DAO interface
 * @return map from "interface name + method name" to Mapper
 */
private static Map<String, Mapper> loadMapperAnnotation(String classPath) {
    Map<String, Mapper> mappers = new HashMap<String, Mapper>();
    try {
        Class<?> daoClass = Class.forName(classPath);
        Method[] methods = daoClass.getMethods();
        for (Method method : methods) {
            if (method.isAnnotationPresent(Select.class)) {
                // One Mapper per method; a single shared instance would make every key point to the last method's SQL
                // (note: the SQL command type is not set here, so getSqlType() returns null for annotated methods)
                Mapper mapper = new Mapper();
                Select selectAnnotation = method.getAnnotation(Select.class);
                mapper.setSql(selectAnnotation.value());
                Type type = method.getGenericReturnType();
                if (type instanceof ParameterizedType) {
                    // e.g. List<User>: the element type is the result type
                    Type[] types = ((ParameterizedType) type).getActualTypeArguments();
                    mapper.setResultType(((Class<?>) types[0]).getName());
                } else if (type instanceof Class) {
                    // e.g. User: the return type itself is the result type
                    mapper.setResultType(((Class<?>) type).getName());
                }
                mappers.put(classPath + "." + method.getName(), mapper);
            }
        }
    } catch (ClassNotFoundException e) {
        e.printStackTrace();
    }
    return mappers;
}
/**
 * XML config: builds Mappers (sql, resultType, ...) from the mapper XML file at the given classpath location
 *
 * @param xmlPath classpath location of the mapper XML, e.g. com/baiye/dao/IUserDao.xml
 * @return map from "namespace.id" to Mapper
 */
private static Map<String, Mapper> loadMapperXMLConfiguration(String xmlPath) {
    Map<String, Mapper> mappers = new HashMap<String, Mapper>();
    // try-with-resources closes the stream even when parsing fails
    try (InputStream in = Resources.getResourceAsStream(xmlPath)) {
        SAXReader saxReader = new SAXReader();
        Document document = saxReader.read(in);
        Element root = document.getRootElement();
        String namespace = root.attributeValue("namespace");
        mappers.putAll(NodeToMapper(namespace, root.selectNodes("//select"), SqlCommandType.SELECT));
        mappers.putAll(NodeToMapper(namespace, root.selectNodes("//insert"), SqlCommandType.INSERT));
        mappers.putAll(NodeToMapper(namespace, root.selectNodes("//update"), SqlCommandType.UPDATE));
        mappers.putAll(NodeToMapper(namespace, root.selectNodes("//delete"), SqlCommandType.DELETE));
    } catch (Exception e) {
        e.printStackTrace();
    }
    return mappers;
}
private static Map<String, Mapper> NodeToMapper(String namespace, List list, SqlCommandType type){
Map<String, Mapper> mappers = new HashMap<String, Mapper>();
Iterator iterator = list.iterator();
while (iterator.hasNext()) {
Element element = (Element) iterator.next();
String id = element.attributeValue("id");
String resultType = element.attributeValue("resultType");
String parameterType = element.attributeValue("parameterType");
String resultMap = element.attributeValue("resultMap");
String key = namespace + "." + id;
String sql = element.getText();
Mapper mapper = new Mapper(key,sql, element, resultType, parameterType, resultMap,type);
mappers.put(key, mapper);
}
return mappers;
}
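Because the maps are merged with putAll, a second statement with the same key silently replaces the first, and MapperProxy's `getMappers().get(...).getSqlType()` fails with a NullPointerException when a key is missing. A minimal sketch of a registry that fails loudly in both cases, keyed the same way as NodeToMapper (namespace + "." + id); MyBatis itself also refuses duplicate statement ids. StatementRegistry is a hypothetical name, and its nested enum simply mirrors the article's SqlCommandType.

```java
import java.util.HashMap;
import java.util.Map;

/** Hypothetical registry keyed like NodeToMapper: namespace + "." + id. */
final class StatementRegistry {
    /** Mirrors the article's SqlCommandType enum. */
    enum SqlCommandType { UNKNOWN, INSERT, UPDATE, DELETE, SELECT, FLUSH }

    private final Map<String, SqlCommandType> statements = new HashMap<>();

    /** Registers a statement, refusing to overwrite an existing key. */
    void register(String namespace, String id, SqlCommandType type) {
        String key = namespace + "." + id;
        SqlCommandType previous = statements.putIfAbsent(key, type);
        if (previous != null) {
            throw new IllegalStateException("duplicate statement id: " + key);
        }
    }

    /** Looks a statement up, failing with a clear message instead of a NullPointerException. */
    SqlCommandType lookup(String statementId) {
        SqlCommandType type = statements.get(statementId);
        if (type == null) {
            throw new IllegalArgumentException("no statement mapped for " + statementId);
        }
        return type;
    }
}
```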
Enum that marks the SQL command type:
public enum SqlCommandType {
UNKNOWN, INSERT, UPDATE, DELETE, SELECT, FLUSH
}
After parsing, the information is wrapped in the Configuration and Mapper classes.
public class Configuration {
protected Environment environment;
private String dataSourceType;
private String driver;
private String url;
private String username;
private String password;
private Cache perpetualCache = new PerpetualCache();
private boolean enableCache;
// key: fully-qualified interface name + "." + method name; value: Mapper
private Map<String, Mapper> mappers = new HashMap<String, Mapper>();
......
}
public class Mapper {
private String sql;
private Element element;
private String resultType;
private String parameterType;
private String resultMap;
......
}
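3. Creating the factory classes
First, create the SqlSessionFactoryBuilder class. It builds the factory from the parsed configuration file and offers several overloads of build.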
public class SqlSessionFactoryBuilder {
private SqlSessionFactory build(InputStream inputStream, String environment, Properties properties) {
Configuration configuration = XMLConfigBuilder.loadConfiguration(inputStream);
return build(configuration);
}
public SqlSessionFactory build(InputStream inputStream) {
return build(inputStream, null, null);
}
public SqlSessionFactory build(Configuration configuration) {
return new DefaultSqlSessionFactory(configuration);
}
}
Following the design of the MyBatis source code, we create the SqlSessionFactory and SqlSession interfaces along with their default implementations, DefaultSqlSessionFactory and DefaultSqlSession.
public interface SqlSessionFactory {
SqlSession openSession();
SqlSession openSession(boolean autoCommit);
}
public interface SqlSession {
<T> T getMapper(Class<T> daoInterfaceClass);
<T> T selectOne(String mapperName, Object[] parameter);
<T> List<T> selectList(String mapperName, Object[] parameter);
int insert(String mapperName, Object[] parameter);
int update(String mapperName, Object[] parameter);
int delete(String mapperName, Object[] parameter);
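void clearCache();
// Used by MapperProxy to look up the SqlCommandType of a statement
Configuration getConfiguration();
/**
 * Closing without committing rolls back the uncommitted changes
 */
void close();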
void rollback();
void commit();
}
DefaultSqlSession holds the configuration, the Executor and the transaction-related state:
public class DefaultSqlSession implements SqlSession {
private final boolean autoCommit;
protected Configuration configuration;
private Executor executor;
private boolean dirty;
.....
}
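The dirty field is what makes the "closing without committing rolls back" rule of the SqlSession interface work. A minimal sketch of that bookkeeping, assuming a hypothetical TxResource interface in place of the real JDBC Connection; MyBatis's own DefaultSqlSession performs a similar `!autoCommit && dirty` check before committing or rolling back. SessionTxSketch is a hypothetical name and does not execute any SQL.

```java
/** Hypothetical sketch of the dirty-flag logic behind autoCommit, commit(), rollback() and close(). */
final class SessionTxSketch {
    /** Hypothetical stand-in for the JDBC connection that the Executor holds. */
    interface TxResource {
        void commit();
        void rollback();
    }

    private final TxResource tx;
    private final boolean autoCommit;
    private boolean dirty; // true once a write has run since the last commit/rollback

    SessionTxSketch(TxResource tx, boolean autoCommit) {
        this.tx = tx;
        this.autoCommit = autoCommit;
    }

    /** insert, update and delete all end up here; the actual SQL execution is omitted. */
    int update(String statementId) {
        dirty = true;
        return 0; // placeholder for the affected-row count
    }

    void commit() {
        if (!autoCommit && dirty) {
            tx.commit();
        }
        dirty = false;
    }

    void rollback() {
        if (!autoCommit && dirty) {
            tx.rollback();
        }
        dirty = false;
    }

    /** Closing with uncommitted writes rolls them back. */
    void close() {
        rollback();
    }
}
```

With autoCommit = true each statement is committed by the connection itself, so neither commit nor close needs to touch the transaction.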
4. The proxy object
The key method is this one:
@Override
public <T> T getMapper(Class<T> daoInterfaceClass) {
return (T) Proxy.newProxyInstance(daoInterfaceClass.getClassLoader(), new Class[]{daoInterfaceClass}, new MapperProxy(this));
}
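It returns a JDK dynamic proxy for the DAO interface; every method call on the proxy is handled by MapperProxy.invoke.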
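getMapper never needs an implementation class for IUserDao: Proxy.newProxyInstance creates one at run time and sends every call to the InvocationHandler. A small self-contained demo of that mechanism; ProxyDemo and DemoDao are hypothetical, and the handler only echoes the statement key and arguments instead of running SQL. The proxy also forwards hashCode, equals and toString to the handler, so a handler that assumes every call is a mapped statement has to answer those itself (MyBatis's MapperProxy checks for methods declared by Object).

```java
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Proxy;
import java.util.Arrays;

/** Hypothetical demo of the JDK dynamic proxy mechanism that getMapper relies on. */
final class ProxyDemo {
    /** Hypothetical DAO interface; there is no implementation class anywhere. */
    interface DemoDao {
        String findById(String id);
    }

    @SuppressWarnings("unchecked")
    static <T> T getMapper(Class<T> daoInterface) {
        InvocationHandler handler = (proxy, method, args) -> {
            if (method.getDeclaringClass() == Object.class) {
                // hashCode, equals and toString are routed to the handler too
                switch (method.getName()) {
                    case "hashCode": return System.identityHashCode(proxy);
                    case "equals": return proxy == args[0];
                    default: return "Proxy for " + daoInterface.getName();
                }
            }
            // A real MapperProxy would delegate to the SqlSession here; this demo only
            // reports the statement key the call maps to and the arguments it received.
            return method.getDeclaringClass().getName() + "." + method.getName() + Arrays.toString(args);
        };
        return (T) Proxy.newProxyInstance(daoInterface.getClassLoader(),
                new Class<?>[]{daoInterface}, handler);
    }
}
```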
Here is the MapperProxy class:
package com.baiye.www.mybaits.proxy;
import com.baiye.www.mybaits.annotation.Param;
import com.baiye.www.mybaits.sqlsession.SqlSession;
import java.lang.annotation.Annotation;
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Method;
import java.util.Collection;
import java.util.HashMap;
import java.util.Map;
/**
* Created with IntelliJ IDEA.
*
* @Author: baiye
* @Date: 2021/07/13/16:54
* @Description:
*/
public class MapperProxy implements InvocationHandler {
// The map key is the fully-qualified interface name + method name
//private Map<String, Mapper> mappers;
private SqlSession sqlSession;
public MapperProxy(SqlSession sqlSession) {
this.sqlSession = sqlSession;
}
@Override
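public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
// hashCode, equals and toString are also routed here; run them on the handler itself
if (Object.class.equals(method.getDeclaringClass())) {
return method.invoke(this, args);
}
// Statement key: fully-qualified interface name + "." + method name
String statementId = method.getDeclaringClass().getName() + "." + method.getName();
System.out.println(statementId);
// In the end the call is handed to sqlSession, because sqlSession wraps the Executor.
// The statement key and the arguments are passed to the matching sqlSession method.
if (args != null && args.length > 1) {
// Several arguments: wrap each one in a single-entry map keyed by its @Param name
Annotation[][] annotations = method.getParameterAnnotations();
for (int i = 0; i < args.length; i++) {
Param anno = null;
for (Annotation a : annotations[i]) {
if (a instanceof Param) {
anno = (Param) a;
}
}
if (anno == null) {
throw new RuntimeException("parameter " + i + " of " + statementId + " has no @Param");
}
Map<String, String> map = new HashMap<>(1);
map.put(anno.value(), args[i] + "");
args[i] = map;
}
}
SqlCommandType sqlType = sqlSession.getConfiguration().getMappers().get(statementId).getSqlType();
if (sqlType.equals(SqlCommandType.SELECT)) {
// Collection return types (e.g. List<User>) go to selectList, everything else to selectOne
if (Collection.class.isAssignableFrom(method.getReturnType())) {
return sqlSession.selectList(statementId, args);
}
return sqlSession.selectOne(statementId, args);
} else if (sqlType.equals(SqlCommandType.UPDATE)||sqlType.equals(SqlCommandType.DELETE)||sqlType.equals(SqlCommandType.INSERT)) {
// insert, update and delete all go through update
return sqlSession.update(statementId, args);
} else {
throw new RuntimeException("no such SqlType");
}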
}
}
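MapperProxy turns each argument into its own single-entry Map<String, String>, so the argument's type is lost (an Integer becomes a String and null becomes "null"). A minimal sketch of an alternative that gathers all @Param-named arguments into one map and keeps the original values; MyBatis's own parameter handling also ends up with a single name-to-value map when there are several named parameters. ParamNames, its nested Param annotation and DemoDao are hypothetical stand-ins. The annotation needs RetentionPolicy.RUNTIME, otherwise getParameterAnnotations() cannot see it.

```java
import java.lang.annotation.Annotation;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.lang.reflect.Method;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

/** Hypothetical helper: collects @Param-named arguments into one map, keeping their types. */
final class ParamNames {
    /** Stand-in for the article's @Param; RUNTIME retention makes it visible to reflection. */
    @Retention(RetentionPolicy.RUNTIME)
    @Target(ElementType.PARAMETER)
    @interface Param {
        String value();
    }

    /** Hypothetical DAO method with two named parameters. */
    interface DemoDao {
        List<Object> findSome(@Param("username") String username, @Param("sex") String sex);
    }

    static Map<String, Object> toParamMap(Method method, Object[] args) {
        Map<String, Object> params = new LinkedHashMap<>(); // keeps declaration order
        Annotation[][] annotations = method.getParameterAnnotations();
        for (int i = 0; i < args.length; i++) {
            String name = null;
            for (Annotation a : annotations[i]) {
                if (a instanceof Param) {
                    name = ((Param) a).value();
                }
            }
            if (name == null) {
                throw new IllegalArgumentException("parameter " + i + " of " + method.getName() + " has no @Param");
            }
            params.put(name, args[i]); // the original object, not args[i] + ""
        }
        return params;
    }
}
```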
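5. Executing SQL statements
The proxy calls the SqlSession methods, and the session finally has the Executor run the SQL statement.
Below is the method I wrote to build the SQL string. It does not follow MyBatis and is very crude and simple: parameter values are spliced directly into the SQL text instead of being bound through a PreparedStatement, so it is open to SQL injection.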
public static String paramToSql(String originalSql, Object[] object) throws InvocationTargetException, IllegalAccessException, IntrospectionException {
    int paramCount = StringUtil.getTargetStringNum(originalSql, "#");
    // No parameters
    if (paramCount < 1) {
        return originalSql;
    } else if (paramCount == 1) { // Exactly one parameter, i.e. a single basic-type value: convert it to a string
        String regex = "#\\{([^}])*}";
        // Replace the #{...} in the SQL with the quoted value; quoteReplacement keeps '$' and '\' in the value literal
        return originalSql.replaceAll(regex, java.util.regex.Matcher.quoteReplacement("\"" + object[0] + "\""));
    } else if (object.length == 1) { // Several placeholders filled from one entity object
        Class<?> obj = object[0].getClass();
        for (Field field : obj.getDeclaredFields()) {
            String fieldName = field.getName();
            // Read the field value through its getter
            PropertyDescriptor propertyDescriptor = new PropertyDescriptor(fieldName, obj);
            Object o = propertyDescriptor.getReadMethod().invoke(object[0]);
            String replacement = (o != null) ? "\"" + o + "\"" : "null";
            originalSql = originalSql.replace("#{" + fieldName + "}", replacement);
        }
        return originalSql;
    } else { // Several placeholders filled from several basic-type arguments
        for (Object o : object) {
            // MapperProxy stores Map<String, String>; read it as Map<String, Object> in case that changes
            @SuppressWarnings("unchecked")
            Map<String, Object> map = (Map<String, Object>) o;
            for (Map.Entry<String, Object> entry : map.entrySet()) {
                Object value = entry.getValue();
                String replacement;
                if (value == null) {
                    replacement = "null";
                } else if (value instanceof Integer || value.equals("1 OR 1")) {
                    replacement = value.toString();
                } else {
                    replacement = "\"" + value + "\"";
                }
                // Replace #{key} in the SQL with the actual argument (a literal replace, no regex needed)
                originalSql = originalSql.replace("#{" + entry.getKey() + "}", replacement);
            }
        }
        return originalSql;
    }
}
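Because paramToSql splices values into the SQL text, a value such as `a" OR "1"="1` passed for `username LIKE #{username}` produces `username LIKE "a" OR "1"="1"` and changes the meaning of the query. MyBatis avoids this by turning each #{...} into a JDBC ? placeholder and binding the values on a PreparedStatement. A minimal sketch of the rewrite half of that approach, assuming placeholders are plain names such as #{username} (no jdbcType or other options); BoundSqlSketch is a hypothetical name.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

/** Hypothetical sketch: rewrites #{name} placeholders to JDBC '?' and records the names in order. */
final class BoundSqlSketch {
    // #{ name } with optional spaces; names may not contain '}' or whitespace
    private static final Pattern PLACEHOLDER = Pattern.compile("#\\{\\s*([^}\\s]+)\\s*}");

    final String sql;                  // SQL to pass to Connection.prepareStatement
    final List<String> parameterNames; // names to bind, in '?' order (JDBC indexes start at 1)

    private BoundSqlSketch(String sql, List<String> parameterNames) {
        this.sql = sql;
        this.parameterNames = parameterNames;
    }

    static BoundSqlSketch parse(String originalSql) {
        Matcher m = PLACEHOLDER.matcher(originalSql);
        StringBuffer out = new StringBuffer();
        List<String> names = new ArrayList<>();
        while (m.find()) {
            names.add(m.group(1));
            m.appendReplacement(out, "?");
        }
        m.appendTail(out);
        return new BoundSqlSketch(out.toString(), Collections.unmodifiableList(names));
    }
}
```

The values would then be bound in parameterNames order, for example with PreparedStatement.setObject(i + 1, value) inside a loop, so the driver treats them strictly as data.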
6. Mapping the results
The Executor turns the rows returned by a query into objects with the following code:
// The metadata is the same for every row, so read it once
ResultSetMetaData metaData = resultSet.getMetaData();
while (resultSet.next()) {
    E pojo = (E) pojoClass.newInstance();
    for (int i = 1; i <= metaData.getColumnCount(); i++) {
        // getColumnLabel returns the "AS" alias if there is one; getColumnName may return the underlying column
        String columnName = metaData.getColumnLabel(i);
        Object value = resultSet.getObject(i);
        // snake_case column name -> camelCase property name, e.g. user_name -> userName
        String propertyName = StringUtil.underlineToHump(columnName);
        PropertyDescriptor propertyDescriptor = new PropertyDescriptor(propertyName, pojoClass);
        Method writeMethod = propertyDescriptor.getWriteMethod();
        writeMethod.invoke(pojo, value);
    }
    list.add(pojo);
}
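StringUtil.underlineToHump is used above but not shown. A minimal sketch of what such a helper could look like (NamingSketch is a hypothetical name): it drops each underscore and upper-cases the letter after it, and leaves all other characters unchanged, so column names that are already camelCase pass through as they are.

```java
/** Hypothetical stand-in for the article's StringUtil.underlineToHump. */
final class NamingSketch {
    /** Converts snake_case to camelCase: "user_name" -> "userName"; other characters are kept as-is. */
    static String underlineToHump(String column) {
        StringBuilder sb = new StringBuilder(column.length());
        boolean upperNext = false;
        for (char c : column.toCharArray()) {
            if (c == '_') {
                // Leading underscores are dropped without upper-casing the next letter
                upperNext = sb.length() > 0;
            } else if (upperNext) {
                sb.append(Character.toUpperCase(c));
                upperNext = false;
            } else {
                sb.append(c);
            }
        }
        return sb.toString();
    }
}
```

If a column has no matching property, the PropertyDescriptor constructor above throws an IntrospectionException, so every selected column needs a corresponding setter on the POJO.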
Project repositories
GitHub: https://github.com/Alice-175/Mybaits
Gitee: https://gitee.com/alice-175/Mybaits