利用dom4j解析相关配置文件并实现mybatis底层源码

目录

1.首先要创建一个maven项目,并导入相关依赖(dom4j)

2.在src\main\resources目录下创建一个mybatis.xml主配置文件和UserMapper.xml映射配置文件。

2.1 mybatis.xml

2.2 UserMapper.xml

3.然后在src\main\java目录下编写相关类,详细说明已在注释中写出

3.1 User类

3.2 Configuration类

3.3 UserMapper类 

4.在src\test\java目录下创建一个测试类CustomerParseTest,用于实现mybatis


1.首先要创建一个maven项目,并导入相关依赖(dom4j)

具体操作点击下面链接就可以看到。

简单地模拟实现Spring解析配置文件并实例化对象_春.生的博客-CSDN博客

pom.xml

<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>

    <groupId>com.funny</groupId>
    <artifactId>MyBatis</artifactId>
    <version>1.0-SNAPSHOT</version>

    <dependencies>
        <dependency>
            <groupId>org.mybatis</groupId>
            <artifactId>mybatis</artifactId>
            <version>3.5.10</version>
        </dependency>
        <dependency>
            <groupId>mysql</groupId>
            <artifactId>mysql-connector-java</artifactId>
            <version>5.1.49</version>
        </dependency>
        <dependency>
            <groupId>org.projectlombok</groupId>
            <artifactId>lombok</artifactId>
            <version>1.18.24</version>
        </dependency>

        <dependency>
            <groupId>junit</groupId>
            <artifactId>junit</artifactId>
            <version>4.13.2</version>
            <scope>test</scope>
        </dependency>

        <dependency>
            <groupId>org.dom4j</groupId>
            <artifactId>dom4j</artifactId>
            <version>2.1.3</version>
        </dependency>
    </dependencies>

</project>

2.在src\main\resources目录下创建一个mybatis.xml主配置文件和UserMapper.xml映射配置文件。

2.1 mybatis.xml

<?xml version="1.0" encoding="UTF-8" ?>
<configuration>
    <environment id="development">
        <dataSource type="POOLED">
            <property name="driver" value="com.mysql.jdbc.Driver"/>
            <property name="url" value="jdbc:mysql://localhost:3306/mydb?useSSL=false&amp;characterEncoding=UTF-8"/>
            <property name="username" value="root"/>
            <property name="password" value=""/>
        </dataSource>
    </environment>
</configuration>

2.2 UserMapper.xml

<mapper namespace="com.funny.mapper.UserMapper">
    <insert id="insert" resultType="com.funny.entity.User">
        insert into tb_user(name, account, password, avatar) values(#{name}, #{account}, #{password}, #{avatar})
    </insert>

    <select id="query" resultType="com.funny.entity.User">
        select * from tb_user
    </select>
</mapper>

3.然后在src\main\java目录下编写相关类,详细说明已在注释中写出

3.1 User类

package com.funny.entity;

import lombok.Data;

import java.io.Serializable;

@Data
public class User implements Serializable {
    private Integer id;
    private String name;
    private String account;
    private String password;
    private String avatar;

}

3.2 Configuration类

package com.funny.config;

import lombok.Data;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
*  用于存储解析配置文件后的相关信息
*/
@Data
public class Configuration {
    private String driver;
    private String url;
    private String username;
    private String password;

    private String className;

    // 执行SQL后返回的类型
    private String resultClassName;
    // 参数名列表
    private List<String> paramNames = new ArrayList<>();
    // 使用map存放sql语句,防止前一条sql语句被后一条sql语句覆盖
    private Map<String,String> sql = new HashMap<>();
}

3.3 UserMapper类 

package com.funny.mapper;

import com.funny.entity.User;

import java.util.List;

/**
* 业务逻辑接口
*/
public interface UserMapper {
    void insert(User user);
    
    List<User> query();
}

4.在src\test\java目录下创建一个测试类CustomerParseTest,用于实现mybatis

package com.funny.parse;

import com.funny.config.Configuration;
import com.funny.entity.User;
import org.dom4j.Document;
import org.dom4j.DocumentException;
import org.dom4j.Element;
import org.dom4j.io.SAXReader;
import org.junit.Test;

import java.io.InputStream;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.sql.*;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;

public class CustomerParseTest {
    // 用于封装数据
    private Configuration configuration = new Configuration();

    private Connection conn = null;
    
    @Test
    public void Test01() throws DocumentException {
        // 解析主配置文件
        parseConfigureFile("mybatis.xml");

        parseMapperFile("UserMapper.xml");

        //获取数据库连接
        getConnection();

        User user = new User();
        user.setName("刘德华");
        user.setAccount("123");
        user.setPassword("123");

        // 执行添加
        insert(user, User.class, "insert");

        // 全查询
        List<User> select = select(User.class, "select");
        for (User user1 : select) {
            System.out.println(user1);
        }
    }

    // 获取连接数据库
    private void getConnection() {
        try {
            Class.forName(configuration.getDriver());
            conn = DriverManager.getConnection(
                    configuration.getUrl(),
                    configuration.getUsername(),
                    configuration.getPassword());
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    // 全查询
    private <T> List<T> select(Class<T> clazz, String bs) {
        ArrayList<T> list = new ArrayList<>();
        try {
            // configuration.getSql().get(bs) 获取Map k: bs  sql中的值
            PreparedStatement pst = conn.prepareStatement(configuration.getSql().get(bs));
            // 执行
            ResultSet rs = pst.executeQuery();
            while (rs.next()) {
                Object obj = clazz.newInstance();

                ResultSetMetaData rsmd = rs.getMetaData();
                //取出总列数
                int columnCount = rsmd.getColumnCount();
                //遍历总列数
                for (int i = 1; i <= columnCount; i++) {
                    //获取每列的名称,列名的序号是从1开始的
                    String columnName = rsmd.getColumnName(i);
                    //根据得到列名,获取每列的值
                    Object columnValue = rs.getObject(columnName);

                    Field field = clazz.getDeclaredField(columnName);
                    field.setAccessible(true);
                    field.set(obj, columnValue);
                }
                list.add((T) obj);
            }
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
        return list;
    }

    private void insert(User user, Class<?> clazz, String bs) {
        try {
            PreparedStatement pst = conn.prepareStatement(configuration.getSql().get(bs));

            //循环设置参数
            for (int i = 0; i < configuration.getParamNames().size(); i++) {
                // 获取指定方法
                Method method = clazz.getDeclaredMethod("get" + firstUpper(configuration.getParamNames().get(i)));
                method.setAccessible(true);
                Object value = method.invoke(user);

                pst.setObject(i + 1, value);
            }

            // 执行
            pst.executeUpdate();

        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    // 把首字母变为大写
    private String firstUpper(String name) {
        char c = (char) (name.charAt(0) - 32);
        return c + name.substring(1);
    }

    // 解析映射文件
    private void parseMapperFile(String xmlFile) {
        try {
            Element root = getRoot(xmlFile);
            String namespace = root.attributeValue("namespace");
            if (namespace != null && !namespace.equals("")) {
                configuration.setClassName(namespace);
            }

            for (Iterator<Element> it = root.elementIterator(); it.hasNext(); ) {
                Element element = it.next();
                String name = element.getName();
                switch (name) {
                    case "insert":
                        //获取返回类型
                        String resultClassName = element.attributeValue("resultType");
                        configuration.setResultClassName(resultClassName);
                        // 获取 sql 语句
                        String sqlText = element.getTextTrim();
                        int index = sqlText.indexOf("(", 30) + 1;
                        String sql = sqlText.substring(0, index);

                        String subSql = sqlText.substring(index, sqlText.length() - 1);
                        String[] params = subSql.split(",");
                        for (int i = 0; i < params.length; i++) {
                            if (i < params.length - 1) {
                                sql += "?,";
                            } else {
                                sql += "?)";
                            }

                            String paramName = params[i].trim().substring(2, params[i].trim().length() - 1);
                            configuration.getParamNames().add(paramName);
                        }
                        configuration.getSql().put("insert", sql);
                        break;
                    case "select":
                        //获取返回类型
                        resultClassName = element.attributeValue("resultType");
                        configuration.setResultClassName(resultClassName);
                        // 获取 sql 语句
                        sqlText = element.getTextTrim();
                        configuration.getSql().put("select", sqlText);
                        break;
                }
            }

        } catch (DocumentException e) {
            throw new RuntimeException(e);
        }
    }

    // 解析主配置文件
    private void parseConfigureFile(String xmlFile) {
        Element dataSource = null;
        try {
            Element root = getRoot(xmlFile);
            // 获取 environment 节点
            Element environment = root.element("environment");
            // 获取 DataSource 节点
            dataSource = environment.element("dataSource");
            for (Iterator<Element> it = dataSource.elementIterator(); it.hasNext(); ) {
                Element element = it.next();
                String name = element.attributeValue("name");
                String value = element.attributeValue("value");

                switch (name) {
                    case "driver":
                        configuration.setDriver(value);
                        break;
                    case "url":
                        configuration.setUrl(value);
                        break;
                    case "username":
                        configuration.setUsername(value);
                        break;
                    case "password":
                        configuration.setPassword(value);
                        break;
                }
            }
        } catch (DocumentException e) {
            throw new RuntimeException(e);
        }


    }

    // 获取根节点
    private Element getRoot(String xmlFile) throws DocumentException {
        // 读取 mybatis.xml文件
        InputStream is = CustomerParseTest.class.getClassLoader().getResourceAsStream(xmlFile);
        // 创建 SAXReader 对象
        SAXReader reader = new SAXReader();
        Document document = reader.read(is);

        // 获取到根节点
        return document.getRootElement();
    }
}

值得注意的是: 在存放sql语句时,要使用集合进行存放,防止出现前一条sql语句被后一条sql语句覆盖。(因为在解析映射文件时,会有多条sql标签,只要存在都会对sql标签进行解析。如果仅用String类型的参数进行存放sql语句,会出现该参数的值始终为最后解析的sql标签的sql语句,无法对其它sql语句进行处理)

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值