Hand-Rolling MyBatis (Part 2): A Mock Implementation

Overview

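In the previous article we walked through the MyBatis source code to understand how it executes a statement. In this article we write our own code to mock a simplified MyBatis.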

Recap

First, a recap of how MyBatis works:

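  1. When SqlSessionFactoryBuilder creates the SqlSessionFactory, the contents of the configuration node in the core configuration file are parsed into the SqlSessionFactory
  2. Obtaining a SqlSession from the SqlSessionFactory returns a DefaultSqlSession
  3. Calling the SqlSession's getMapper method returns a proxy object for the DAO interface; the proxy performs the insert, update, delete, and select operations in its invoke method
  4. Those operations are carried out with JDBC and the mapped SQL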
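Steps 3 and 4 rest on one trick: a JDK dynamic proxy turns a call on an interface into a lookup keyed by "interface name + method name". The minimal, database-free sketch below isolates that trick. MiniFlow and DemoDAO are hypothetical names for this demo, and instead of running the SQL through JDBC the proxy simply returns the SQL text it found.

```java
import java.lang.reflect.Proxy;
import java.util.HashMap;
import java.util.Map;

public class MiniFlow {

    // Hypothetical DAO interface used only for this demo
    public interface DemoDAO {
        String findAll();
        String findById(Integer id);
    }

    // Stand-in for step 1: the "parsed configuration", keyed by interface name + "." + method name
    public static Map<String, String> statements() {
        Map<String, String> map = new HashMap<>();
        map.put(DemoDAO.class.getName() + ".findAll", "select * from user");
        map.put(DemoDAO.class.getName() + ".findById", "SELECT * FROM USER where id=?");
        return map;
    }

    // Stand-in for steps 3 and 4: the proxy finds the statement for the called method.
    // A real implementation would execute it with JDBC; this demo just returns the SQL text.
    public static <T> T getMapper(Class<T> clazz, Map<String, String> statements) {
        Object proxy = Proxy.newProxyInstance(
                clazz.getClassLoader(),
                new Class<?>[]{clazz},
                (p, method, args) -> statements.get(clazz.getName() + "." + method.getName()));
        return clazz.cast(proxy);
    }

    public static void main(String[] args) {
        DemoDAO dao = getMapper(DemoDAO.class, statements());
        System.out.println(dao.findAll());    // prints: select * from user
        System.out.println(dao.findById(1));  // prints: SELECT * FROM USER where id=?
    }
}
```

The mock below follows the same shape: the parsers build the statement map, and XMLMapperProxy replaces the lambda with real JDBC calls.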

Implementation

Below we mock the process described above:
1. Add the configuration file mock-config.xml

<?xml version="1.0" encoding="UTF-8" ?>
<configuration>
    <environments default="develop">
        <environment id="develop">
            <transactionManager type="JDBC"></transactionManager>
            <dataSource type="POOLED">
                <property name="driver" value="com.mysql.jdbc.Driver"/>
                <property name="url" value="jdbc:mysql://localhost:3306/test_db"/>
                <property name="username" value="root"/>
                <property name="password" value="123456"/>
            </dataSource>
        </environment>
    </environments>
    <mappers>
        <!--<package name="com.qf.mbs.mapper"/>-->
        <mapper resource="com/qf/mbs/mapper/UserDAO.xml"/>
        <mapper class="com.qf.mbs.dao.UserDAO"/>
    </mappers>
</configuration>

The DAO interface:

package com.qf.mbs.dao;

import java.util.List;

import com.qf.mbs.po.User;

public interface UserDAO {

    List<User> findAll();
    User findById(Integer id);
    void addUser(User user);
    void updateUser(User user);
    void deleteUser(Integer id);
}

The mapper file (com/qf/mbs/mapper/UserDAO.xml):

<?xml version="1.0" encoding="UTF-8" ?>
<!DOCTYPE mapper
        PUBLIC "-//mybatis.org//DTD Mapper 3.0//EN"
        "http://mybatis.org/dtd/mybatis-3-mapper.dtd">
<mapper namespace="com.qf.mbs.dao.UserDAO">
    <insert id="addUser" parameterType="com.qf.mbs.po.User">
        <selectKey keyProperty="id" resultType="int" order="AFTER">
            SELECT LAST_INSERT_ID()
        </selectKey>
        INSERT INTO USER (USERNAME,BIRTHDAY,SEX,ADDRESS) VALUES (#{username},#{birthday},#{sex},#{address})
    </insert>
    <update id="updateUser" parameterType="com.qf.mbs.po.User">
        UPDATE  USER set username=#{username},birthday=#{birthday},sex=#{sex},address=#{address} where id=#{id}
    </update>
    <delete id="deleteUser" parameterType="java.lang.Integer">
        DELETE FROM USER where id=#{id}
    </delete>
    <select id="findById" parameterType="java.lang.Integer" resultType="com.qf.mbs.po.User">
        SELECT * FROM USER where id=#{id}
    </select>
    <select id="findAll" resultType="com.qf.mbs.po.User">
        select * from user
    </select>
</mapper>

Define a parser class for the core configuration file; it extracts the data source settings and the mapper entries:

import java.io.ByteArrayInputStream;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import org.dom4j.Document;
import org.dom4j.DocumentException;
import org.dom4j.Element;
import org.dom4j.io.SAXReader;
import org.xml.sax.InputSource;

/**
 * Parser for the core configuration file
 */
public class ConfigParser {
    // Data source settings (driver, url, username, password)
    private Map<String, String> dataSource = new HashMap<>();
    // Paths of the XML mapper files (<mapper resource="...">)
    private List<String> resourceMappers = new ArrayList<>();
    // Fully qualified names of mapper interfaces (<mapper class="...">)
    private List<String> classMappers = new ArrayList<>();

    /** Parses the configuration */
    public void parseConfig(InputStream inputStream) throws DocumentException {
        SAXReader reader = new SAXReader();
        // Resolve every external entity (such as a DTD) to an empty document, so nothing is downloaded
        reader.setEntityResolver((publicId, systemId) -> new InputSource(new ByteArrayInputStream(new byte[0])));
        Document doc = reader.read(inputStream);
        Element rootElement = doc.getRootElement();
        // Only the first <environment> is read; its "default" attribute is ignored in this mock
        List<Element> properties = rootElement.element("environments").element("environment").element("dataSource").elements();
        for (Element property : properties) {
            String name = property.attributeValue("name");
            if ("driver".equals(name) || "url".equals(name) || "username".equals(name) || "password".equals(name)) {
                dataSource.put(name, property.attributeValue("value"));
            }
        }
        List<Element> mapperEles = rootElement.element("mappers").elements();
        for (Element mapper : mapperEles) {
            if (mapper.attribute("resource") != null) {
                resourceMappers.add(mapper.attributeValue("resource"));
            }
            if (mapper.attribute("class") != null) {
                classMappers.add(mapper.attributeValue("class"));
            }
        }
    }

    public Map<String, String> getDataSource() {
        return dataSource;
    }

    public List<String> getResourceMappers() {
        return resourceMappers;
    }

    public List<String> getClassMappers() {
        return classMappers;
    }
}
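dom4j is a third-party dependency. If you would rather avoid it, the same parsing can be sketched with the DOM parser that ships with the JDK (javax.xml.parsers). DomConfigParser is a hypothetical name; unlike ConfigParser it keeps every property element it finds rather than only the four known keys, and it wraps parser errors in an unchecked IllegalArgumentException.

```java
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.StringReader;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import javax.xml.parsers.DocumentBuilder;
import javax.xml.parsers.DocumentBuilderFactory;
import javax.xml.parsers.ParserConfigurationException;
import org.w3c.dom.Document;
import org.w3c.dom.Element;
import org.w3c.dom.NodeList;
import org.xml.sax.InputSource;
import org.xml.sax.SAXException;

public class DomConfigParser {

    private final Map<String, String> dataSource = new HashMap<>();
    private final List<String> resourceMappers = new ArrayList<>();
    private final List<String> classMappers = new ArrayList<>();

    public void parseConfig(InputStream in) {
        try {
            DocumentBuilder builder = DocumentBuilderFactory.newInstance().newDocumentBuilder();
            // Same idea as the dom4j version: resolve external entities (DTDs) to an empty document
            builder.setEntityResolver((publicId, systemId) -> new InputSource(new StringReader("")));
            Document doc = builder.parse(in);

            // Every <property name=".." value=".."> under the first <dataSource>
            Element ds = (Element) doc.getElementsByTagName("dataSource").item(0);
            NodeList props = ds.getElementsByTagName("property");
            for (int i = 0; i < props.getLength(); i++) {
                Element p = (Element) props.item(i);
                dataSource.put(p.getAttribute("name"), p.getAttribute("value"));
            }

            // <mapper resource=".."> and <mapper class=".."> entries (XML comments are not elements)
            NodeList mappers = doc.getElementsByTagName("mapper");
            for (int i = 0; i < mappers.getLength(); i++) {
                Element m = (Element) mappers.item(i);
                if (m.hasAttribute("resource")) {
                    resourceMappers.add(m.getAttribute("resource"));
                }
                if (m.hasAttribute("class")) {
                    classMappers.add(m.getAttribute("class"));
                }
            }
        } catch (ParserConfigurationException | SAXException | IOException e) {
            throw new IllegalArgumentException("Cannot parse configuration", e);
        }
    }

    public Map<String, String> getDataSource() { return dataSource; }
    public List<String> getResourceMappers() { return resourceMappers; }
    public List<String> getClassMappers() { return classMappers; }

    public static void main(String[] args) {
        String xml = "<configuration><environments default=\"develop\"><environment id=\"develop\">"
                + "<dataSource type=\"POOLED\">"
                + "<property name=\"driver\" value=\"com.mysql.jdbc.Driver\"/>"
                + "<property name=\"url\" value=\"jdbc:mysql://localhost:3306/test_db\"/>"
                + "</dataSource></environment></environments>"
                + "<mappers><mapper resource=\"com/qf/mbs/mapper/UserDAO.xml\"/>"
                + "<mapper class=\"com.qf.mbs.dao.UserDAO\"/></mappers></configuration>";
        DomConfigParser parser = new DomConfigParser();
        parser.parseConfig(new ByteArrayInputStream(xml.getBytes(StandardCharsets.UTF_8)));
        System.out.println(parser.getDataSource());
        System.out.println(parser.getResourceMappers());
        System.out.println(parser.getClassMappers());
    }
}
```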

Test:

@Test
public void testParse(){
    ConfigParser parser = new ConfigParser();
    try {
        parser.parseConfig(Test2.class.getClassLoader().getResourceAsStream("mock-config.xml"));
        System.out.println(parser.getDataSource());
        System.out.println(parser.getClassMappers());
        System.out.println(parser.getResourceMappers());
    } catch (DocumentException e) {
        e.printStackTrace();
    }
}

Write a class that holds the information for each mapper method:

/**
 * SQL mapping for one mapper method
 */
public class Mapper {
    // SQL type: insert, update, delete or select
    private String sqlType;
    // Mapper method name
    private String methodName;
    // SQL statement
    private String sql;
    // Parameter type
    private String paramType;
    // Return type
    private String returnType;
    // .... (getters and setters omitted)
}

Define a parser class for the mapper files:

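import java.io.ByteArrayInputStream;
import java.io.InputStream;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import org.dom4j.Document;
import org.dom4j.DocumentException;
import org.dom4j.Element;
import org.dom4j.io.SAXReader;
import org.xml.sax.InputSource;

/**
 * Parser for the XML mapper files
 */
public class MapperParser {

    // Key: namespace + "." + statement id, the same "interface name.method name" key the proxy looks up
    private Map<String, Mapper> mappers = new HashMap<>();

    public void parseMappers(InputStream inputStream) throws DocumentException {
        SAXReader reader = new SAXReader();
        // Resolve the mapper DTD to an empty document instead of downloading it
        reader.setEntityResolver((publicId, systemId) -> new InputSource(new ByteArrayInputStream(new byte[0])));
        Document doc = reader.read(inputStream);
        Element root = doc.getRootElement();
        List<Element> elements = root.elements();
        String namespace = root.attributeValue("namespace");
        for (Element e : elements) {
            Mapper mapper = new Mapper();
            // The element name (insert/update/delete/select) is the SQL type
            mapper.setSqlType(e.getName());
            mapper.setMethodName(e.attributeValue("id"));
            mapper.setParamType(e.attributeValue("parameterType"));
            mapper.setReturnType(e.attributeValue("resultType"));
            // Only the element's own text is kept; child elements such as <selectKey> are skipped
            mapper.setSql(e.getTextTrim());
            mappers.put(namespace + "." + e.attributeValue("id"), mapper);
        }
    }

    public Map<String, Mapper> getMappers() {
        return mappers;
    }
}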
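MapperParser stores the SQL exactly as written, #{...} placeholders included, and JDBC cannot prepare that text as-is. As an assumption about how the mock could be extended (this is not MyBatis's own code), here is a sketch that rewrites each #{name} to a JDBC ? placeholder and records the parameter names in order. PlaceholderSql is a hypothetical class.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

public final class PlaceholderSql {

    // Matches #{name}, allowing optional spaces inside the braces
    private static final Pattern PLACEHOLDER = Pattern.compile("#\\{\\s*([^}]+?)\\s*\\}");

    private final String jdbcSql;
    private final List<String> paramNames;

    private PlaceholderSql(String jdbcSql, List<String> paramNames) {
        this.jdbcSql = jdbcSql;
        this.paramNames = Collections.unmodifiableList(paramNames);
    }

    /** Replaces each #{name} with ? and records the names in order of appearance */
    public static PlaceholderSql parse(String rawSql) {
        Matcher m = PLACEHOLDER.matcher(rawSql);
        StringBuffer sb = new StringBuffer(); // StringBuffer keeps this usable on Java 8
        List<String> names = new ArrayList<>();
        while (m.find()) {
            names.add(m.group(1));
            m.appendReplacement(sb, "?");
        }
        m.appendTail(sb);
        return new PlaceholderSql(sb.toString(), names);
    }

    public String getJdbcSql() { return jdbcSql; }

    public List<String> getParamNames() { return paramNames; }

    public static void main(String[] args) {
        PlaceholderSql ps = parse("DELETE FROM USER where id=#{id}");
        System.out.println(ps.getJdbcSql());    // prints: DELETE FROM USER where id=?
        System.out.println(ps.getParamNames()); // prints: [id]
    }
}
```

With the name list in hand, a proxy could bind each value with PreparedStatement.setObject(i + 1, value), taking the value from the single argument for methods such as findById or deleteUser, or from the matching getter (for example getUsername()) when the argument is a User.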

Test:

@Test
public void testParseMappers(){
    ConfigParser parser = new ConfigParser();
    MapperParser mParser = new MapperParser();
    try {
        parser.parseConfig(Test2.class.getClassLoader().getResourceAsStream("mock-config.xml"));
        System.out.println(parser.getDataSource());
        System.out.println(parser.getClassMappers());
        System.out.println(parser.getResourceMappers());
        for(String resource : parser.getResourceMappers()){
            System.out.println("resource:"+resource);
            mParser.parseMappers(Test2.class.getClassLoader().getResourceAsStream(resource));
            System.out.println(mParser.getMappers());
        }
    } catch (DocumentException e) {
        e.printStackTrace();
    }
}

Write a mock SqlSession that returns mapper objects through a dynamic proxy:

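/**
 * Mock of SqlSession
 */
public interface MySqlSession {
    /**
     * Returns a mapper object for the given interface
     * @param clazz the mapper interface
     * @param <T> the mapper type
     * @return a proxy implementing clazz
     */
    <T> T getMapper(Class<T> clazz);
}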

public class MySqlSessionImpl implements MySqlSession {

    private Map<String,String> dataSource;
    private Map<String,Mapper> mappers;

    public MySqlSessionImpl(Map<String, String> dataSource, Map<String, Mapper> mappers) {
        this.dataSource = dataSource;
        this.mappers = mappers;
    }

    @Override
    public <T> T getMapper(Class<T> clazz) {
        XMLMapperProxy xmlMapperProxy = new XMLMapperProxy(dataSource,mappers,clazz);
        return clazz.cast(xmlMapperProxy.getProxy());
    }
}

The InvocationHandler that implements the dynamic proxy:

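import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.ResultSetMetaData;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Map;

/**
 * Proxy backed by the XML mapper files
 */
public class XMLMapperProxy implements InvocationHandler {

    private Map<String, String> dataSource;
    private Map<String, Mapper> mappers;
    private Class<?> clazz;

    public XMLMapperProxy(Map<String, String> dataSource, Map<String, Mapper> mappers, Class<?> clazz) {
        this.dataSource = dataSource;
        this.mappers = mappers;
        this.clazz = clazz;
        initDatasource(dataSource);
    }

    /**
     * Creates the proxy object
     * @return a proxy implementing the mapper interface
     */
    public Object getProxy() {
        return Proxy.newProxyInstance(clazz.getClassLoader(), new Class<?>[]{clazz}, this);
    }

    /**
     * Runs the SQL mapped to the called method
     * @return the query result, or the update count
     */
    @Override
    public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
        // toString/hashCode/equals are declared on Object and have no mapped SQL
        if (method.getDeclaringClass() == Object.class) {
            return method.invoke(this, args);
        }
        String key = clazz.getName() + "." + method.getName();
        // Look up the Mapper by "interface name.method name"
        Mapper mapper = mappers.get(key);
        System.out.println("mapper--->" + mapper);
        if (mapper == null) {
            System.out.println(key + ",has no mapper.");
            return null;
        }
        // Note: method arguments are not bound and #{...} placeholders are not translated,
        // so only parameterless statements such as findAll run as-is.
        // A connection is opened per call and closed together with its statement,
        // so the same mapper can be called repeatedly.
        try (Connection connection = getConnection();
             PreparedStatement psmt = connection.prepareStatement(mapper.getSql())) {
            // Non-query commands return the update count
            if (!"select".equals(mapper.getSqlType())) {
                return psmt.executeUpdate();
            }
            List<Object> selectRs = new ArrayList<>();
            try (ResultSet rs = psmt.executeQuery()) {
                ResultSetMetaData meta = rs.getMetaData();
                List<String> columns = new ArrayList<>();
                for (int i = 1; i <= meta.getColumnCount(); i++) {
                    columns.add(meta.getColumnLabel(i));
                }
                Class<?> poClazz = Class.forName(mapper.getReturnType());
                // Reflection: for each column, call the setter named "set" + column name (case-insensitive)
                while (rs.next()) {
                    Object po = poClazz.getDeclaredConstructor().newInstance();
                    for (String colName : columns) {
                        for (Method med : poClazz.getMethods()) {
                            if (med.getParameterCount() == 1 && med.getName().equalsIgnoreCase("set" + colName)) {
                                med.invoke(po, readColumn(rs, colName, med.getParameterTypes()[0]));
                            }
                        }
                    }
                    selectRs.add(po);
                }
            }
            // List methods (findAll) get the whole list; single-object methods (findById) get the first row or null
            if (Collection.class.isAssignableFrom(method.getReturnType())) {
                return selectRs;
            }
            return selectRs.isEmpty() ? null : selectRs.get(0);
        }
    }

    /**
     * Reads a column with the ResultSet getter that matches the setter's parameter type
     */
    private Object readColumn(ResultSet rs, String colName, Class<?> type) throws SQLException {
        if (type == String.class) {
            return rs.getString(colName);
        }
        if (type == Integer.class || type == int.class) {
            return rs.getInt(colName);
        }
        if (type == java.sql.Date.class) {
            return rs.getDate(colName);
        }
        if (type == java.util.Date.class) {
            return rs.getTimestamp(colName);
        }
        return rs.getObject(colName);
    }

    /**
     * Opens a JDBC connection from the parsed data source settings
     */
    private Connection getConnection() throws SQLException {
        return DriverManager.getConnection(dataSource.get("url"), dataSource.get("username"), dataSource.get("password"));
    }

    /**
     * Initializes the data source by loading the JDBC driver class
     */
    private void initDatasource(Map<String, String> dataSource) {
        try {
            Class.forName(dataSource.get("driver"));
        } catch (ClassNotFoundException e) {
            e.printStackTrace();
        }
    }
}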
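The proxy scans poClazz.getMethods() for every column of every row. A common refinement is to build a setter lookup table once per class and reuse it for all rows. The sketch below shows that idea without a database: a row is modeled as a Map<String, Object> (column label to value) instead of a ResultSet, and DemoUser is a hypothetical stand-in for the article's User class. If a class has overloaded setters with the same name, the last one found wins in this simple table.

```java
import java.lang.reflect.Method;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.Locale;
import java.util.Map;

public class RowMapperDemo {

    // Hypothetical PO standing in for User
    public static class DemoUser {
        private Integer id;
        private String username;
        public Integer getId() { return id; }
        public void setId(Integer id) { this.id = id; }
        public String getUsername() { return username; }
        public void setUsername(String username) { this.username = username; }
    }

    /** Builds a table from lower-case property name to one-argument setter, once per class */
    public static Map<String, Method> setters(Class<?> type) {
        Map<String, Method> result = new HashMap<>();
        for (Method m : type.getMethods()) {
            String name = m.getName();
            if (m.getParameterCount() == 1 && name.length() > 3 && name.startsWith("set")) {
                result.put(name.substring(3).toLowerCase(Locale.ROOT), m);
            }
        }
        return result;
    }

    /** Copies each column into the matching setter; columns without a setter are skipped */
    public static <T> T mapRow(Map<String, Object> row, Class<T> type, Map<String, Method> setters) {
        try {
            T po = type.getDeclaredConstructor().newInstance();
            for (Map.Entry<String, Object> col : row.entrySet()) {
                Method setter = setters.get(col.getKey().toLowerCase(Locale.ROOT));
                if (setter != null) {
                    setter.invoke(po, col.getValue());
                }
            }
            return po;
        } catch (ReflectiveOperationException e) {
            throw new IllegalStateException("Cannot map row to " + type.getName(), e);
        }
    }

    public static void main(String[] args) {
        Map<String, Method> userSetters = setters(DemoUser.class);
        Map<String, Object> row = new LinkedHashMap<>();
        row.put("ID", 1);
        row.put("USERNAME", "tom");
        row.put("CREATED_AT", "ignored"); // no matching setter, so it is skipped
        DemoUser u = mapRow(row, DemoUser.class, userSetters);
        System.out.println(u.getId() + " " + u.getUsername()); // prints: 1 tom
    }
}
```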

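The mock SqlSessionFactory. Unlike MyBatis, which parses the configuration once when the factory is built, this mock re-parses it every time openSession is called: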

public class MySqlSessionFactory {

    private String config;

    public MySqlSessionFactory(String config) {
        this.config = config;
    }

    public MySqlSession openSession(){
        ConfigParser parser = new ConfigParser();
        MapperParser mParser = new MapperParser();
        try {
            parser.parseConfig(MySqlSessionFactory.class.getClassLoader().getResourceAsStream(config));
            System.out.println(parser.getDataSource());
            System.out.println(parser.getClassMappers());
            System.out.println(parser.getResourceMappers());
            for(String resource : parser.getResourceMappers()){
                System.out.println("resource:"+resource);
                mParser.parseMappers(MySqlSessionFactory.class.getClassLoader().getResourceAsStream(resource));
                System.out.println(mParser.getMappers());
            }
            return new MySqlSessionImpl(parser.getDataSource(),mParser.getMappers());
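        } catch (DocumentException e) {
            // Fail fast instead of returning null, which would only surface later as a NullPointerException
            throw new IllegalStateException("Failed to parse " + config, e);
        }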
    }
}

The factory builder:

public class MySqlSessionFactoryBuilder {

    public MySqlSessionFactory build(String config){
        return new MySqlSessionFactory(config);
    }
}

Test:

@Test
public void testSelect(){
    MySqlSessionFactory factory = new MySqlSessionFactoryBuilder().build("mock-config.xml");
    MySqlSession mySqlSession = factory.openSession();
    UserDAO mapper = mySqlSession.getMapper(UserDAO.class);
    List<User> users = mapper.findAll();
    System.out.println("users--->"+users);
}

Summary

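Using the MyBatis source code as a guide, we mocked a simplified version of its implementation, which gives us a deeper understanding of the framework. Along the way it also reinforces reflection, dynamic proxies, JDBC, and object-oriented design, and if the need arises you could go on to write a framework of your own.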

That's all for this article. If you found it useful, please give it a like :)


