概述
前面的文章中，我们通过阅读源码的方式，带大家了解了MyBatis的执行过程。本文我们将自己动手编写代码，模拟MyBatis的一个简单实现。
回顾
先回顾MyBatis的实现过程:
- 通过SqlSessionFactoryBuilder创建SqlSessionFactory时，将核心配置文件中configuration节点的内容解析到SqlSessionFactory中
- 通过SqlSessionFactory获得SqlSession时，返回的是DefaultSqlSession
- 调用SqlSession的getMapper方法时，返回DAO接口的代理对象，代理对象在invoke方法中实现了增删改查操作
- 具体的增删改查操作通过JDBC+SQL实现
编码实现
下面我们模拟实现上面的过程:
1、添加配置文件mock-config.xml
<?xml version="1.0" encoding="UTF-8" ?>
<!-- Core configuration of the mock framework; parsed by ConfigParser. -->
<configuration>
<environments default="develop">
<environment id="develop">
<!-- ConfigParser does not read transactionManager or the dataSource "type" attribute. -->
<transactionManager type="JDBC"></transactionManager>
<dataSource type="POOLED">
<!-- ConfigParser copies the driver, url, username and password properties into its dataSource map;
     the mapper proxy uses them to open JDBC connections. -->
<!-- NOTE(review): com.mysql.jdbc.Driver is the Connector/J 5.x class name; Connector/J 8 uses
     com.mysql.cj.jdbc.Driver. Confirm against the driver version on the classpath. -->
<!-- NOTE(review): credentials are stored in plain text; acceptable for a local demo only. -->
<property name="driver" value="com.mysql.jdbc.Driver"/>
<property name="url" value="jdbc:mysql://localhost:3306/test_db"/>
<property name="username" value="root"/>
<property name="password" value="123456"/>
</dataSource>
</environment>
</environments>
<mappers>
<!-- ConfigParser only recognizes the "resource" and "class" attributes; package entries are ignored. -->
<!-- "resource" files are loaded from the classpath and parsed by MapperParser; "class" entries are
     collected by ConfigParser but not used by MySqlSessionFactory. -->
<!--<package name="com.qf.mbs.mapper"/>-->
<mapper resource="com/qf/mbs/mapper/UserDAO.xml"/>
<mapper class="com.qf.mbs.dao.UserDAO"/>
</mappers>
</configuration>
DAO接口:
/**
 * Data-access interface for the USER table. It has no hand-written implementation: the mock
 * session returns a dynamic proxy driven by the statements in UserDAO.xml.
 */
public interface UserDAO {

    /**
     * Loads every user.
     *
     * @return all rows of the USER table
     */
    List<User> findAll();

    /**
     * Loads a single user by primary key.
     *
     * @param id the user id to look up
     * @return the user with that id
     */
    User findById(Integer id);

    /**
     * Inserts a new user row.
     *
     * @param user the values to insert (username, birthday, sex, address)
     */
    void addUser(User user);

    /**
     * Updates the row whose id equals {@code user}'s id.
     *
     * @param user the new values plus the id of the row to change
     */
    void updateUser(User user);

    /**
     * Deletes a user row.
     *
     * @param id the id of the row to delete
     */
    void deleteUser(Integer id);
}
Mapper文件:
<?xml version="1.0" encoding="UTF-8" ?>
<!DOCTYPE mapper
PUBLIC "-//mybatis.org//DTD Mapper 3.0//EN"
"http://mybatis.org/dtd/mybatis-3-mapper.dtd">
<!-- The DTD above is not downloaded: MapperParser resolves external entities to empty content. -->
<mapper namespace="com.qf.mbs.dao.UserDAO">
<!-- Statements are stored under namespace + "." + id and looked up by interface name + "." + method
     name, so the namespace must be the DAO's fully qualified name and each id its method name. -->
<!-- MapperParser does not process selectKey, so the generated id is not copied back into the User. -->
<insert id="addUser" parameterType="com.qf.mbs.po.User">
<selectKey keyProperty="id" resultType="int" order="AFTER">
SELECT LAST_INSERT_ID()
</selectKey>
INSERT INTO USER (USERNAME,BIRTHDAY,SEX,ADDRESS) VALUES (#{username},#{birthday},#{sex},#{address})
</insert>
<update id="updateUser" parameterType="com.qf.mbs.po.User">
UPDATE USER set username=#{username},birthday=#{birthday},sex=#{sex},address=#{address} where id=#{id}
</update>
<delete id="deleteUser" parameterType="java.lang.Integer">
DELETE FROM USER where id=#{id}
</delete>
<select id="findById" parameterType="java.lang.Integer" resultType="com.qf.mbs.po.User">
SELECT * FROM USER where id=#{id}
</select>
<select id="findAll" resultType="com.qf.mbs.po.User">
select * from user
</select>
</mapper>
定义一个解析类，完成对核心配置文件的解析工作（数据源信息和映射文件列表）：
/**
 * Parser for the core configuration file (e.g. mock-config.xml).
 *
 * <p>Collects the JDBC connection properties of the selected environment and the mapper
 * locations declared under {@code <mappers>}. Results accumulate across calls to
 * {@link #parseConfig(InputStream)}.
 */
public class ConfigParser {
    /** dataSource property names that are copied into {@link #getDataSource()}. */
    private static final List<String> DATA_SOURCE_KEYS =
            java.util.Arrays.asList("driver", "url", "username", "password");

    // Data source properties keyed by name (driver, url, username, password)
    private Map<String,String> dataSource = new HashMap<>();
    // Classpath locations taken from <mapper resource="...">
    private List<String> resourceMappers = new ArrayList<>();
    // Fully qualified interface names taken from <mapper class="...">
    private List<String> classMappers = new ArrayList<>();

    /**
     * Parses the configuration and records its data source and mapper declarations.
     *
     * <p>The environment whose {@code id} matches {@code <environments default="...">} is used.
     * If nothing matches, the first {@code <environment>} is used.
     *
     * @param inputStream the configuration XML; not closed by this method
     * @throws DocumentException if the stream is null, the XML is malformed, or the
     *     environments/environment/dataSource elements are missing
     */
    public void parseConfig(InputStream inputStream) throws DocumentException {
        if (inputStream == null) {
            throw new DocumentException(
                    "configuration input stream is null (resource not found on the classpath?)");
        }
        SAXReader reader = new SAXReader();
        // Resolve every external entity (e.g. a DOCTYPE's DTD) to empty content so parsing stays offline.
        reader.setEntityResolver(new EntityResolver() {
            @Override
            public InputSource resolveEntity(String publicId, String systemId) throws SAXException, IOException {
                return new InputSource(new ByteArrayInputStream("".getBytes()));
            }
        });
        Document doc = reader.read(inputStream);
        Element rootElement = doc.getRootElement();
        parseDataSource(rootElement);
        parseMapperDeclarations(rootElement);
    }

    /** Copies the known dataSource properties of the selected environment into {@link #dataSource}. */
    private void parseDataSource(Element rootElement) throws DocumentException {
        Element environments = rootElement.element("environments");
        if (environments == null) {
            throw new DocumentException("missing <environments> element");
        }
        Element environment = selectEnvironment(environments);
        Element dataSourceElement = environment == null ? null : environment.element("dataSource");
        if (dataSourceElement == null) {
            throw new DocumentException("missing <environment>/<dataSource> element");
        }
        List<Element> properties = dataSourceElement.elements("property");
        for (Element property : properties) {
            String name = property.attributeValue("name");
            if (DATA_SOURCE_KEYS.contains(name)) {
                dataSource.put(name, property.attributeValue("value"));
            }
        }
    }

    /** Returns the environment named by the "default" attribute, else the first one, else null. */
    private static Element selectEnvironment(Element environments) {
        String defaultId = environments.attributeValue("default");
        List<Element> candidates = environments.elements("environment");
        for (Element candidate : candidates) {
            if (defaultId != null && defaultId.equals(candidate.attributeValue("id"))) {
                return candidate;
            }
        }
        return candidates.isEmpty() ? null : candidates.get(0);
    }

    /** Records the resource/class attribute of each child of {@code <mappers>}; package entries are ignored. */
    private void parseMapperDeclarations(Element rootElement) {
        Element mappersElement = rootElement.element("mappers");
        if (mappersElement == null) {
            return; // no mappers declared
        }
        List<Element> mapperEles = mappersElement.elements();
        for (Element mapper : mapperEles) {
            String resource = mapper.attributeValue("resource");
            if (resource != null) {
                resourceMappers.add(resource);
            }
            String className = mapper.attributeValue("class");
            if (className != null) {
                classMappers.add(className);
            }
        }
    }

    /** @return data source properties keyed by name (live, mutable map) */
    public Map<String, String> getDataSource() {
        return dataSource;
    }

    /** @return classpath locations of mapper XML files (live, mutable list) */
    public List<String> getResourceMappers() {
        return resourceMappers;
    }

    /** @return mapper interface names from class attributes (live, mutable list) */
    public List<String> getClassMappers() {
        return classMappers;
    }
}
测试:
/**
 * Smoke test for ConfigParser: parses mock-config.xml and prints what it found.
 * Parse errors propagate, so JUnit reports a failure instead of the test passing silently.
 */
@Test
public void testParse() throws DocumentException, java.io.IOException {
    ConfigParser parser = new ConfigParser();
    // try-with-resources closes the classpath stream; a null resource is simply not closed.
    try (java.io.InputStream in = Test2.class.getClassLoader().getResourceAsStream("mock-config.xml")) {
        parser.parseConfig(in);
    }
    System.out.println(parser.getDataSource());
    System.out.println(parser.getClassMappers());
    System.out.println(parser.getResourceMappers());
    if (parser.getDataSource().get("url") == null) {
        throw new AssertionError("url was not parsed from mock-config.xml");
    }
}
编写一个类,保存Mapper方法相关的信息
/**
 * One SQL statement parsed from a mapper XML file.
 *
 * <p>Plain holder: MapperParser fills it and the mapper proxy reads it.
 */
public class Mapper {
    // Statement kind: the XML element name (insert, update, delete or select)
    private String sqlType;
    // Statement id; matched against the DAO method name
    private String methodName;
    // SQL text as written in the mapper file, including any #{...} placeholders
    private String sql;
    // Value of the parameterType attribute, or null when absent
    private String paramType;
    // Value of the resultType attribute, or null when absent
    private String returnType;

    public String getSqlType() {
        return sqlType;
    }

    public void setSqlType(String sqlType) {
        this.sqlType = sqlType;
    }

    public String getMethodName() {
        return methodName;
    }

    public void setMethodName(String methodName) {
        this.methodName = methodName;
    }

    public String getSql() {
        return sql;
    }

    public void setSql(String sql) {
        this.sql = sql;
    }

    public String getParamType() {
        return paramType;
    }

    public void setParamType(String paramType) {
        this.paramType = paramType;
    }

    public String getReturnType() {
        return returnType;
    }

    public void setReturnType(String returnType) {
        this.returnType = returnType;
    }

    @Override
    public String toString() {
        return "Mapper{sqlType='" + sqlType + "', methodName='" + methodName + "', sql='" + sql
                + "', paramType='" + paramType + "', returnType='" + returnType + "'}";
    }
}
定义一个解析类,完成对映射文件的解析工作:
/**
 * Parser for mapper XML files.
 *
 * <p>Produces {@link Mapper} statements keyed by {@code namespace + "." + id}. Statements
 * accumulate across calls, so several mapper files can share one parser.
 */
public class MapperParser {
    /** Element names that define a SQL statement; other children (e.g. sql, resultMap) are skipped. */
    private static final List<String> STATEMENT_TAGS =
            java.util.Arrays.asList("select", "insert", "update", "delete");

    // Parsed statements keyed by namespace + "." + statement id
    private Map<String,Mapper> mappers = new HashMap<>();

    /**
     * Parses one mapper XML document and adds its statements to {@link #getMappers()}.
     *
     * @param inputStream the mapper XML; not closed by this method
     * @throws DocumentException if the stream is null or the XML cannot be parsed
     */
    public void parseMappers(InputStream inputStream) throws DocumentException {
        if (inputStream == null) {
            throw new DocumentException(
                    "mapper input stream is null (resource not found on the classpath?)");
        }
        SAXReader reader = new SAXReader();
        // Resolve every external entity (e.g. the mapper DTD) to empty content so parsing stays offline.
        reader.setEntityResolver(new EntityResolver() {
            @Override
            public InputSource resolveEntity(String publicId, String systemId) throws SAXException, IOException {
                return new InputSource(new ByteArrayInputStream("".getBytes()));
            }
        });
        Document doc = reader.read(inputStream);
        Element root = doc.getRootElement();
        String namespace = root.attributeValue("namespace");
        List<Element> elements = root.elements();
        for (Element e : elements) {
            if (!STATEMENT_TAGS.contains(e.getName())) {
                continue;
            }
            Mapper mapper = new Mapper();
            mapper.setSqlType(e.getName());
            mapper.setMethodName(e.attributeValue("id"));
            mapper.setParamType(e.attributeValue("parameterType"));
            mapper.setReturnType(e.attributeValue("resultType"));
            // getText() returns only this element's own text (children such as selectKey are
            // skipped); trim the surrounding newlines/indentation.
            mapper.setSql(e.getText().trim());
            mappers.put(namespace + "." + e.attributeValue("id"), mapper);
        }
    }

    /** @return parsed statements keyed by namespace + "." + id (live, mutable map) */
    public Map<String, Mapper> getMappers() {
        return mappers;
    }
}
测试:
/**
 * Smoke test for ConfigParser + MapperParser: parses mock-config.xml and every mapper it lists.
 * Parse errors propagate, so JUnit reports a failure instead of the test passing silently.
 */
@Test
public void testParse() throws DocumentException, java.io.IOException {
    ConfigParser parser = new ConfigParser();
    MapperParser mParser = new MapperParser();
    ClassLoader loader = Test2.class.getClassLoader();
    try (java.io.InputStream configStream = loader.getResourceAsStream("mock-config.xml")) {
        parser.parseConfig(configStream);
    }
    System.out.println(parser.getDataSource());
    System.out.println(parser.getClassMappers());
    System.out.println(parser.getResourceMappers());
    for (String resource : parser.getResourceMappers()) {
        System.out.println("resource:" + resource);
        try (java.io.InputStream mapperStream = loader.getResourceAsStream(resource)) {
            mParser.parseMappers(mapperStream);
        }
        System.out.println(mParser.getMappers());
    }
    if (mParser.getMappers().isEmpty()) {
        throw new AssertionError("no mapped statements were parsed");
    }
}
编写模拟SqlSession的类，通过动态代理获得Mapper对象
/**
 * Stand-in for MyBatis' SqlSession: hands out implementations of mapper (DAO) interfaces.
 */
public interface MySqlSession {

    /**
     * Obtains an object that implements the given mapper interface.
     *
     * @param clazz the mapper interface, e.g. {@code UserDAO.class}
     * @param <T>   the mapper type
     * @return an implementation of {@code clazz}
     */
    <T> T getMapper(Class<T> clazz);
}
/**
 * Session implementation that answers every getMapper call with a fresh XML-driven proxy.
 */
public class MySqlSessionImpl implements MySqlSession {
    // Connection settings as parsed by ConfigParser (driver, url, username, password)
    private Map<String,String> dataSource;
    // Parsed statements keyed by namespace + "." + statement id
    private Map<String,Mapper> mappers;

    /**
     * @param dataSource connection settings
     * @param mappers    parsed statements
     */
    public MySqlSessionImpl(Map<String, String> dataSource, Map<String, Mapper> mappers) {
        this.mappers = mappers;
        this.dataSource = dataSource;
    }

    /**
     * Builds a new {@link XMLMapperProxy} for {@code clazz} and returns its proxy object.
     */
    @Override
    @SuppressWarnings("unchecked")
    public <T> T getMapper(Class<T> clazz) {
        return (T) new XMLMapperProxy(dataSource, mappers, clazz).getProxy();
    }
}
动态代理的实现类
/**
 * InvocationHandler that implements a DAO interface from statements parsed out of mapper XML.
 *
 * <p>Each DAO call follows these steps:
 * <ol>
 *   <li>Resolve the call to the {@link Mapper} stored under {@code <interface FQN>.<method name>}.</li>
 *   <li>Rewrite its {@code #{name}} placeholders to JDBC {@code ?} markers and bind the method arguments.</li>
 *   <li>Run the statement on a connection opened for that call only and closed before returning.</li>
 * </ol>
 */
public class XMLMapperProxy implements InvocationHandler{
    /** Matches {@code #{name}} or {@code #{name,jdbcType=...}}; group 1 is the parameter name. */
    private static final java.util.regex.Pattern PLACEHOLDER =
            java.util.regex.Pattern.compile("#\\{\\s*([^},\\s]+)[^}]*\\}");

    private Map<String,String> dataSource;
    private Map<String,Mapper> mappers;
    private Class<?> clazz;

    /**
     * @param dataSource connection settings with keys driver, url, username, password
     * @param mappers    statements keyed by namespace + "." + statement id
     * @param clazz      the DAO interface to implement
     * @throws IllegalStateException if the configured JDBC driver class is not on the classpath
     */
    public XMLMapperProxy(Map<String, String> dataSource, Map<String,Mapper> mappers,Class clazz) {
        this.dataSource = dataSource;
        this.mappers = mappers;
        this.clazz = clazz;
        initDatasource(dataSource);
    }

    /**
     * Creates the proxy object.
     *
     * @return an instance implementing the DAO interface, backed by this handler
     */
    public Object getProxy(){
        // Use the interface's own loader so the proxy can always see the interface type.
        return Proxy.newProxyInstance(clazz.getClassLoader(), new Class[]{clazz}, this);
    }

    /**
     * Executes the statement mapped to {@code method}.
     *
     * <p>Return values depend on the statement type:
     * <ul>
     *   <li>select, when the method's return type accepts a {@link List}: all rows.</li>
     *   <li>select, otherwise: the single row, or null when there is none.</li>
     *   <li>insert/update/delete: the update count, adapted to int/long/boolean. The count is ignored for void methods.</li>
     * </ul>
     */
    @Override
    public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
        // equals/hashCode/toString are dispatched here too; answer them without touching the database.
        if (method.getDeclaringClass() == Object.class) {
            return invokeObjectMethod(proxy, method, args);
        }
        String key = clazz.getName() + "." + method.getName();
        // Look up the statement registered for this interface method
        Mapper mapper = mappers.get(key);
        System.out.println("mapper--->"+mapper);
        if(mapper == null){
            System.out.println(key+",has no mapper.");
            return null;
        }
        List<String> paramNames = new ArrayList<>();
        String sql = toJdbcSql(mapper.getSql(), paramNames);
        // A fresh connection per call; it is closed together with the statement when the call ends.
        try (Connection connection = openConnection();
             PreparedStatement psmt = connection.prepareStatement(sql)) {
            bindParameters(psmt, paramNames, args);
            if ("select".equals(mapper.getSqlType())) {
                Class<?> resultClass = resolveResultClass(mapper, key);
                List<Object> rows;
                try (ResultSet rs = psmt.executeQuery()) {
                    rows = mapRows(rs, resultClass);
                }
                return adaptSelectResult(rows, method);
            }
            return adaptUpdateCount(psmt.executeUpdate(), method.getReturnType());
        }
    }

    /** Handles the java.lang.Object methods that Proxy routes to the handler. */
    private Object invokeObjectMethod(Object proxy, Method method, Object[] args) {
        String name = method.getName();
        if ("equals".equals(name)) {
            return proxy == args[0];
        }
        if ("hashCode".equals(name)) {
            return System.identityHashCode(proxy);
        }
        if ("toString".equals(name)) {
            return "XMLMapperProxy for " + clazz.getName();
        }
        throw new UnsupportedOperationException(method.toString());
    }

    /** Replaces each #{name} with ?, appending the names to {@code paramNames} in order. */
    private static String toJdbcSql(String sql, List<String> paramNames) {
        java.util.regex.Matcher matcher = PLACEHOLDER.matcher(sql);
        StringBuffer jdbcSql = new StringBuffer();
        while (matcher.find()) {
            paramNames.add(matcher.group(1));
            matcher.appendReplacement(jdbcSql, "?");
        }
        matcher.appendTail(jdbcSql);
        return jdbcSql.toString().trim();
    }

    /**
     * Binds method arguments to the placeholders. It handles four cases:
     * <ul>
     *   <li>Several arguments are bound positionally.</li>
     *   <li>A single Map argument is read by key.</li>
     *   <li>A single simple value fills every placeholder.</li>
     *   <li>A single bean is read through its getters.</li>
     * </ul>
     */
    private static void bindParameters(PreparedStatement ps, List<String> names, Object[] args) throws Exception {
        int argCount = args == null ? 0 : args.length;
        for (int i = 0; i < names.size(); i++) {
            String name = names.get(i);
            Object value;
            if (argCount == 0) {
                throw new IllegalArgumentException("SQL parameter #{" + name + "} has no method argument to bind");
            } else if (argCount > 1) {
                if (i >= argCount) {
                    throw new IllegalArgumentException("SQL parameter #{" + name + "} has no method argument at position " + i);
                }
                value = args[i];
            } else if (args[0] instanceof Map) {
                value = ((Map<?, ?>) args[0]).get(name);
            } else if (isSimpleValue(args[0])) {
                value = args[0];
            } else {
                value = readProperty(args[0], name);
            }
            ps.setObject(i + 1, toJdbcValue(value));
        }
    }

    /** Reads property {@code name} from {@code bean} via its public getX()/isX() method (case-insensitive). */
    private static Object readProperty(Object bean, String name) throws Exception {
        for (Method m : bean.getClass().getMethods()) {
            if (m.getParameterTypes().length == 0
                    && (m.getName().equalsIgnoreCase("get" + name) || m.getName().equalsIgnoreCase("is" + name))) {
                return m.invoke(bean);
            }
        }
        throw new IllegalArgumentException("No getter for property '" + name + "' on " + bean.getClass().getName());
    }

    /** Converts values that setObject does not portably accept (plain java.util.Date, enums). */
    private static Object toJdbcValue(Object value) {
        if (value instanceof Enum) {
            return ((Enum<?>) value).name();
        }
        if (value != null && value.getClass() == java.util.Date.class) {
            return new java.sql.Timestamp(((java.util.Date) value).getTime());
        }
        return value;
    }

    /** Loads the class named by the statement's resultType. */
    private static Class<?> resolveResultClass(Mapper mapper, String key) throws ClassNotFoundException {
        if (mapper.getReturnType() == null) {
            throw new IllegalStateException("select statement " + key + " has no resultType");
        }
        return Class.forName(mapper.getReturnType());
    }

    /**
     * Converts every row into a {@code resultClass} instance. Columns are matched to setters by
     * label, ignoring case. A simple result type (String, number, Boolean, Date) takes column 1.
     */
    private static List<Object> mapRows(ResultSet rs, Class<?> resultClass) throws Exception {
        java.sql.ResultSetMetaData meta = rs.getMetaData();
        int count = meta.getColumnCount();
        List<Object> rows = new ArrayList<>();
        if (isSimpleType(resultClass)) {
            while (rs.next()) {
                rows.add(readColumn(rs, 1, resultClass));
            }
            return rows;
        }
        // Resolve setters once: setters[i] receives column i + 1 (null when there is no matching setter).
        Method[] setters = new Method[count];
        for (int i = 0; i < count; i++) {
            setters[i] = findSetter(resultClass, meta.getColumnLabel(i + 1));
        }
        while (rs.next()) {
            Object po = resultClass.getDeclaredConstructor().newInstance();
            for (int i = 0; i < count; i++) {
                Method setter = setters[i];
                if (setter == null) {
                    continue;
                }
                Class<?> type = setter.getParameterTypes()[0];
                Object value = readColumn(rs, i + 1, type);
                // A primitive setter cannot take SQL NULL; leave that field at its default.
                if (value == null && type.isPrimitive()) {
                    continue;
                }
                setter.invoke(po, value);
            }
            rows.add(po);
        }
        return rows;
    }

    /** Finds a public one-argument setter named "set" + column (case-insensitive), or null. */
    private static Method findSetter(Class<?> type, String column) {
        for (Method m : type.getMethods()) {
            if (m.getParameterTypes().length == 1 && m.getName().equalsIgnoreCase("set" + column)) {
                return m;
            }
        }
        return null;
    }

    /** Reads column {@code index} using the JDBC getter that matches {@code type}; SQL NULL becomes null. */
    private static Object readColumn(ResultSet rs, int index, Class<?> type) throws SQLException {
        Class<?> target = box(type);
        Object value;
        if (target == String.class) {
            value = rs.getString(index);
        } else if (target == Integer.class) {
            value = rs.getInt(index);
        } else if (target == Long.class) {
            value = rs.getLong(index);
        } else if (target == Double.class) {
            value = rs.getDouble(index);
        } else if (target == Float.class) {
            value = rs.getFloat(index);
        } else if (target == Short.class) {
            value = rs.getShort(index);
        } else if (target == Byte.class) {
            value = rs.getByte(index);
        } else if (target == Boolean.class) {
            value = rs.getBoolean(index);
        } else if (target == java.math.BigDecimal.class) {
            value = rs.getBigDecimal(index);
        } else if (target == java.sql.Date.class) {
            value = rs.getDate(index);
        } else if (target == java.sql.Time.class) {
            value = rs.getTime(index);
        } else if (target == java.sql.Timestamp.class || target == java.util.Date.class) {
            value = rs.getTimestamp(index);
        } else {
            value = rs.getObject(index);
        }
        // Primitive getters return 0/false for SQL NULL; report NULL as null instead.
        return rs.wasNull() ? null : value;
    }

    /** True for types bound or read as a single column value rather than as a bean. */
    private static boolean isSimpleType(Class<?> type) {
        Class<?> t = box(type);
        return t == String.class || t == Boolean.class || t == Character.class
                || Number.class.isAssignableFrom(t) || java.util.Date.class.isAssignableFrom(t);
    }

    /** True when {@code value} should be bound directly instead of being read as a bean. */
    private static boolean isSimpleValue(Object value) {
        return value == null || value instanceof Enum || isSimpleType(value.getClass());
    }

    /** Maps a primitive type to its wrapper; other types are returned unchanged. */
    private static Class<?> box(Class<?> type) {
        if (!type.isPrimitive()) {
            return type;
        }
        if (type == int.class) {
            return Integer.class;
        }
        if (type == long.class) {
            return Long.class;
        }
        if (type == double.class) {
            return Double.class;
        }
        if (type == float.class) {
            return Float.class;
        }
        if (type == short.class) {
            return Short.class;
        }
        if (type == byte.class) {
            return Byte.class;
        }
        if (type == boolean.class) {
            return Boolean.class;
        }
        if (type == char.class) {
            return Character.class;
        }
        return Void.class;
    }

    /** Returns all rows for List-compatible return types, otherwise the single row (or null). */
    private static Object adaptSelectResult(List<Object> rows, Method method) {
        if (method.getReturnType().isAssignableFrom(List.class)) {
            return rows;
        }
        if (rows.isEmpty()) {
            return null;
        }
        if (rows.size() > 1) {
            throw new IllegalStateException("Expected one result (or null) from " + method.getName()
                    + ", but found " + rows.size());
        }
        return rows.get(0);
    }

    /** Adapts the JDBC update count to the method's declared return type. */
    private static Object adaptUpdateCount(int count, Class<?> returnType) {
        if (returnType == long.class || returnType == Long.class) {
            return (long) count;
        }
        if (returnType == boolean.class || returnType == Boolean.class) {
            return count > 0;
        }
        return count; // Integer; Proxy discards it for void methods
    }

    /** Opens a new JDBC connection from the configured url/username/password. */
    private Connection openConnection() throws SQLException {
        return DriverManager.getConnection(dataSource.get("url"), dataSource.get("username"), dataSource.get("password"));
    }

    /**
     * Loads the configured JDBC driver class, if one is configured.
     *
     * @throws IllegalStateException if the driver class is not on the classpath
     */
    private void initDatasource(Map<String,String> dataSource){
        String driver = dataSource.get("driver");
        if (driver == null) {
            // DriverManager can still locate JDBC 4 drivers through the service-provider mechanism.
            return;
        }
        try {
            Class.forName(driver);
        } catch (ClassNotFoundException e) {
            throw new IllegalStateException("JDBC driver class not found: " + driver, e);
        }
    }
}
模拟SqlSessionFactory工厂代码：
/**
 * Stand-in for SqlSessionFactory.
 *
 * <p>Each {@link #openSession()} parses the configuration file and the mapper files it lists,
 * then returns a session backed by the result.
 */
public class MySqlSessionFactory {
    // Classpath location of the core configuration file
    private String config;

    /**
     * @param config classpath location of the core configuration file, e.g. "mock-config.xml"
     */
    public MySqlSessionFactory(String config) {
        this.config = config;
    }

    /**
     * Parses the configuration and all resource mappers, then creates a session.
     *
     * @return a new session, never null
     * @throws IllegalStateException if the configuration or a mapper file cannot be found, read or
     *     parsed; the original exception is attached as the cause
     */
    public MySqlSession openSession(){
        ConfigParser parser = new ConfigParser();
        MapperParser mParser = new MapperParser();
        try (java.io.InputStream configStream = openResource(config)) {
            parser.parseConfig(configStream);
        } catch (DocumentException | java.io.IOException e) {
            throw new IllegalStateException("Failed to load configuration: " + config, e);
        }
        System.out.println(parser.getDataSource());
        System.out.println(parser.getClassMappers());
        System.out.println(parser.getResourceMappers());
        for(String resource : parser.getResourceMappers()){
            System.out.println("resource:"+resource);
            try (java.io.InputStream mapperStream = openResource(resource)) {
                mParser.parseMappers(mapperStream);
            } catch (DocumentException | java.io.IOException e) {
                throw new IllegalStateException("Failed to load mapper: " + resource, e);
            }
            System.out.println(mParser.getMappers());
        }
        return new MySqlSessionImpl(parser.getDataSource(), mParser.getMappers());
    }

    /** Opens a classpath resource, failing with a descriptive exception when it does not exist. */
    private static java.io.InputStream openResource(String name) throws java.io.FileNotFoundException {
        java.io.InputStream in = name == null
                ? null
                : MySqlSessionFactory.class.getClassLoader().getResourceAsStream(name);
        if (in == null) {
            throw new java.io.FileNotFoundException("classpath resource not found: " + name);
        }
        return in;
    }
}
工厂的创建器
/**
 * Entry point that creates a {@link MySqlSessionFactory} for a configuration resource.
 */
public class MySqlSessionFactoryBuilder {

    /**
     * Creates a factory for the given configuration. Nothing is parsed here; the factory only
     * stores the location.
     *
     * @param config classpath location of the core configuration file
     * @return a factory bound to {@code config}
     */
    public MySqlSessionFactory build(String config){
        MySqlSessionFactory factory = new MySqlSessionFactory(config);
        return factory;
    }
}
测试:
/**
 * End-to-end check: builds a factory from mock-config.xml and runs UserDAO.findAll through the proxy.
 */
@Test
public void testSelect(){
    UserDAO userDAO = new MySqlSessionFactoryBuilder()
            .build("mock-config.xml")
            .openSession()
            .getMapper(UserDAO.class);
    List<User> users = userDAO.findAll();
    System.out.println("users--->" + users);
}
总结
这里我们参考MyBatis的源码，简单模拟了它的实现过程，这样我们对MyBatis框架有了更深入的理解，同时也能更深刻地掌握反射、动态代理、JDBC、面向对象等知识点。如果以后有需要，还可以编写属于自己的框架哦。
本文就到这里了,如果对你有用的话,就点个赞吧:)
大家如果需要学习其他Java知识点,戳这里 超详细的Java知识点汇总