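Hand-Writing MyBatis (Part 1): Implementing CRUD

The source code links are at the end of the article.

I wrote this post after the code was finished, so all code shown is the final version.

Introduction:

This is a personal learning project. The code is not very polished and is a bit messy.
It imitates MyBatis, based on studying the MyBatis source code (still being improved).
Implemented so far:
CRUD (with XML configuration; I wrote a little of the annotation-based configuration, but the principle is much the same, so it is not covered here),
transactions,
a database connection pool,
dynamic SQL (partially; the principle works, and I will finish it when I have time),
first- and second-level caches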

I. Where to Start

Let's start with this code:

        //1. Read the configuration file
        InputStream in = Resources.getResourceAsStream("SqlMapConfig.xml");
        //2. Create the SqlSessionFactory
        SqlSessionFactoryBuilder builder = new SqlSessionFactoryBuilder();
        SqlSessionFactory factory = builder.build(in);
        //3. Use the factory to produce a SqlSession (true = autoCommit)
        SqlSession session = factory.openSession(true);

        //SqlSession session = SessionUtil.getSession();
        //4. Use the SqlSession to create a proxy object for the DAO interface
        IUserDao userDao = session.getMapper(IUserDao.class);
        //5. Execute methods through the proxy (the commented-out lines are other test cases)
        //List<User> user = userDao.findAll();
        //User u = new User();
        //u.setSex("男");
        //User user = userDao.findById("41");
        //List<User> user = userDao.findSome("%王%", "男");
        //java.util.Date(int year, ...) counts years from 1900, so new Date(1920, 11, 23) would be 3820-12-23;
        //GregorianCalendar takes the real year (month 11 = December)
        java.util.Date birthday = new java.util.GregorianCalendar(1920, java.util.Calendar.DECEMBER, 23).getTime();
        List<User> user = userDao.findSomeOgnl(new User(105, "%王%", birthday, null, "江西"));
        //userDao.insertUser(new User(105, "小q", birthday, "男", "江西"));
        //userDao.deletebyId("49");
        //long time = System.currentTimeMillis();
        //java.sql.Date date = new java.sql.Date(time);
        //java.sql.Timestamp timestamp = new Timestamp(time);
        //userDao.update(new User(42, "小花k", timestamp, "女", "江西"));
        //System.out.println(user);
        for (User u : user) {
            System.out.println(u);
        }
        //6. Release resources
        //session.commit();
        session.close();

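This code shows how MyBatis is used to execute SQL, with some of my test cases mixed in.
The steps are spelled out in the comments: a factory produces the SqlSession, and the SqlSession then creates a proxy object for the DAO interface. Our job is to implement these classes.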

II. Reading the Configuration Files

There are several ways to read and parse XML, such as DOM, DOM4J and SAX. I use dom4j.
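For comparison, here is a minimal sketch of the same `<property>` extraction using the JDK's built-in SAX API instead of dom4j (dom4j's `SAXReader` builds its document tree from SAX events like these). `SaxPropertyReader` and `readProperties` are illustrative names, not part of the project, and the sketch assumes the `name`/`value` attributes used in the configuration file.

```java
import java.io.ByteArrayInputStream;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.util.LinkedHashMap;
import java.util.Map;
import javax.xml.parsers.SAXParserFactory;
import org.xml.sax.Attributes;
import org.xml.sax.helpers.DefaultHandler;

public class SaxPropertyReader {

    /** Streams through the XML and collects <property name=".." value=".."/> pairs. */
    public static Map<String, String> readProperties(InputStream in) throws Exception {
        Map<String, String> props = new LinkedHashMap<>();
        DefaultHandler handler = new DefaultHandler() {
            @Override
            public void startElement(String uri, String localName, String qName, Attributes attrs) {
                // SAX reports each start tag as an event; no tree is kept in memory
                if ("property".equals(qName)) {
                    props.put(attrs.getValue("name"), attrs.getValue("value"));
                }
            }
        };
        SAXParserFactory.newInstance().newSAXParser().parse(in, handler);
        return props;
    }

    public static void main(String[] args) throws Exception {
        String xml = "<dataSource type=\"POOLED\">"
                + "<property name=\"driver\" value=\"com.mysql.cj.jdbc.Driver\"/>"
                + "<property name=\"username\" value=\"root\"/>"
                + "</dataSource>";
        System.out.println(readProperties(
                new ByteArrayInputStream(xml.getBytes(StandardCharsets.UTF_8))));
    }
}
```

SAX suits large files because nothing is retained except what the handler stores; dom4j trades that for the convenience of XPath queries over a tree.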

  1. Reading the main configuration file
    First, let's look at the MyBatis XML configuration:
<?xml version="1.0" encoding="UTF-8"?>
<!-- MyBatis main configuration file -->
<configuration>
    <!-- configure the environments -->
    <environments default="mysql">
        <!-- the MySQL environment -->
        <environment id="mysql">
            <!-- transaction manager type -->
            <transactionManager type="JDBC"></transactionManager>
            <!-- data source (connection pool) -->
            <dataSource type="POOLED">
                <!-- the 4 basic settings for connecting to the database -->
                <property name="driver" value="com.mysql.cj.jdbc.Driver"/>
                <property name="url" value="jdbc:mysql://localhost:3306/eesy_mybatis?serverTimezone=GMT%2B8&amp;useSSL=false"/>
                <property name="username" value="root"/>
                <property name="password" value=""/>
            </dataSource>
        </environment>
    </environments>

    <!-- locations of the mapper files; a mapper file is the separate configuration file of each DAO -->
    <mappers>
        <mapper resource="com/baiye/dao/IUserDao.xml"/>
<!--        <mapper class="com.baiye.www.dao.IUserDao"/>-->
    </mappers>
</configuration>

It has to be parsed; here is the code:

public static Configuration loadConfiguration(InputStream in) {
        Configuration config = new Configuration();
        SAXReader saxReader = new SAXReader();
        try {
            Document document = saxReader.read(in);
            Element root = document.getRootElement();
            //XPath: '//' selects matching nodes anywhere in the document, not only children of the current node
            List dataSourceNode = root.selectNodes("//dataSource");
            String type = ((Element) dataSourceNode.get(0)).attributeValue("type");
            config.setDataSourceType(type);
            List nodes = root.selectNodes("//property");
            Iterator iterator = nodes.iterator();
            while (iterator.hasNext()) {
                Element element = (Element) iterator.next();
                String name = element.attributeValue("name");
                String value = element.attributeValue("value");
                if ("driver".equals(name)) {
                    config.setDriver(value);
                } else if ("url".equals(name)) {
                    config.setUrl(value);
                } else if ("username".equals(name)) {
                    config.setUsername(value);
                } else if ("password".equals(name)) {
                    config.setPassword(value);
                } else {
                    throw new RuntimeException("xml sql config error!");
                }
            }

            //mappers
            List mapperList = root.selectNodes("//mappers/mapper");
            Iterator mapperIterator = mapperList.iterator();
            Map<String, Mapper> mappers = new HashMap<>();
            while (mapperIterator.hasNext()) {
                Element element = (Element) mapperIterator.next();
                Attribute attribute = element.attribute("resource");
                if (attribute != null) {
                    //XML mapper file
                    String classPath = attribute.getValue();
                    Map<String, Mapper> mapper = loadMapperXMLConfiguration(classPath);
                    mappers.putAll(mapper);
                } else if (element.attribute("class") != null) {
                    //annotated mapper interface
                    String classPath = element.attribute("class").getValue();
                    Map<String, Mapper> mapper = loadMapperAnnotation(classPath);
                    mappers.putAll(mapper);
                } else {
                    throw new RuntimeException("xml mapper config error");
                }
            }
            config.setMappers(mappers);
        } catch (Exception e) {
            e.printStackTrace();
            System.out.println("loadConfiguration error!");
        }

        //constant-first equals: if parsing failed above, dataSourceType is null and would otherwise throw a NullPointerException
        if ("POOLED".equals(config.getDataSourceType())) {
            config.setEnvironment(new Environment(null, null, new PooledDataSourceFactory(config).getDataSource()));
        } else if ("UNPOOLED".equals(config.getDataSourceType())) {
            config.setEnvironment(new Environment(null, null, new UnpooledDataSourceFactory(config).getDataSource()));
        }
        return config;
    }
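`loadConfiguration` selects every `<property>` with `//property` and takes the first `<dataSource>`, so it ignores the `default` attribute of `<environments>` and would mix settings if several environments were configured. A minimal sketch of restricting the lookup to the default environment, assuming the JDK's built-in DOM and XPath APIs rather than dom4j; `DefaultEnvironmentReader` is a hypothetical name.

```java
import java.io.ByteArrayInputStream;
import java.nio.charset.StandardCharsets;
import java.util.LinkedHashMap;
import java.util.Map;
import javax.xml.parsers.DocumentBuilderFactory;
import javax.xml.xpath.XPath;
import javax.xml.xpath.XPathConstants;
import javax.xml.xpath.XPathFactory;
import org.w3c.dom.Document;
import org.w3c.dom.Element;
import org.w3c.dom.NodeList;

public class DefaultEnvironmentReader {

    /** Returns the dataSource properties of the environment named by environments/@default. */
    public static Map<String, String> readDefaultEnvironment(String xml) throws Exception {
        Document doc = DocumentBuilderFactory.newInstance().newDocumentBuilder()
                .parse(new ByteArrayInputStream(xml.getBytes(StandardCharsets.UTF_8)));
        XPath xpath = XPathFactory.newInstance().newXPath();
        // the predicate keeps only the <environment> whose id equals the default attribute
        String expr = "/configuration/environments/environment"
                + "[@id = /configuration/environments/@default]/dataSource/property";
        NodeList nodes = (NodeList) xpath.evaluate(expr, doc, XPathConstants.NODESET);
        Map<String, String> props = new LinkedHashMap<>();
        for (int i = 0; i < nodes.getLength(); i++) {
            Element e = (Element) nodes.item(i);
            props.put(e.getAttribute("name"), e.getAttribute("value"));
        }
        return props;
    }

    public static void main(String[] args) throws Exception {
        String xml = "<configuration><environments default=\"mysql\">"
                + "<environment id=\"test\"><dataSource type=\"UNPOOLED\">"
                + "<property name=\"username\" value=\"tester\"/></dataSource></environment>"
                + "<environment id=\"mysql\"><dataSource type=\"POOLED\">"
                + "<property name=\"username\" value=\"root\"/></dataSource></environment>"
                + "</environments></configuration>";
        // only the properties of the "mysql" environment are returned
        System.out.println(readDefaultEnvironment(xml));
    }
}
```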
  2. Reading the mapper configuration file
    The mapper file looks like this:
<?xml version="1.0" encoding="UTF-8"?>
<mapper namespace="com.baiye.www.dao.IUserDao">
    <!-- find a user by id -->
    <select id="findById" resultType="com.baiye.www.domain.User" parameterType="Integer">
        select * from user where id= #{id}
    </select>

    <select id="findAll" resultType="com.baiye.www.domain.User">
        select * from user
    </select>

    <select id="findSome" resultType="com.baiye.www.domain.User">
        select * from `user` where username LIKE #{username} and sex = #{sex}
    </select>
    <select id="findSomeOgnl" resultType="com.baiye.www.domain.User" parameterType="com.baiye.www.domain.User">
        select * from `user` where username LIKE #{username} and username LIKE #{username}
        <if test="sex != null">
            and sex = #{sex}
        </if>
    </select>
    <insert id="insertUser" parameterType="com.baiye.www.domain.User">
        insert into user(username,birthday,sex,address) values(#{username},#{birthday},#{sex},#{address})
    </insert>
    <delete id="deletebyId" parameterType="Integer">
        delete from user where id = #{id}
    </delete>
    <update id="update" parameterType="com.baiye.www.domain.User">
        update user set username=#{username},birthday=#{birthday},sex=#{sex} where id=#{id}
    </update>
</mapper>

The code below contains one loader for annotation-based configuration and one for XML configuration:

/**
     * Annotation config: build Mappers (sql, resultType) from the given fully qualified interface name
     *
     * @param classPath fully qualified name of the DAO interface
     * @return map from "interface name.method name" to its Mapper
     */
    private static Map<String, Mapper> loadMapperAnnotation(String classPath) {
        Map<String, Mapper> mappers = new HashMap<String, Mapper>();
        try {
            Class<?> daoClass = Class.forName(classPath);
            Method[] methods = daoClass.getMethods();
            for (Method method : methods) {
                if (method.isAnnotationPresent(Select.class)) {
                    Select selectAnnotation = method.getAnnotation(Select.class);
                    String sql = selectAnnotation.value();
                    String methodName = method.getName();
                    String resultType = null;
                    Type type = method.getGenericReturnType();
                    //parameterized return type such as List<User>: use the element type
                    if (type instanceof ParameterizedType) {
                        ParameterizedType ptype = (ParameterizedType) type;
                        Type[] types = ptype.getActualTypeArguments();
                        Class<?> domainClass = (Class<?>) types[0];
                        resultType = domainClass.getName();
                    } else if (type instanceof Class) {
                        //plain return type such as User
                        resultType = ((Class<?>) type).getName();
                    }
                    String key = classPath + "." + methodName;
                    //a new Mapper per method: one shared instance would leave every key pointing at the last SQL read;
                    //annotation SQL has no XML element, and MapperProxy needs the SqlCommandType
                    Mapper mapper = new Mapper(key, sql, null, resultType, null, null, SqlCommandType.SELECT);
                    mappers.put(key, mapper);
                }
            }

        } catch (ClassNotFoundException e) {
            e.printStackTrace();
        }
        return mappers;
    }


    /**
     * XML config: build Mappers (sql, resultType, ...) from the given mapper XML path
     *
     * @param xmlPath classpath location of the mapper XML file
     * @return map from "namespace.id" to its Mapper
     */
    private static Map<String, Mapper> loadMapperXMLConfiguration(String xmlPath) {
        Map<String, Mapper> mappers = new HashMap<String, Mapper>();
        //try-with-resources closes the stream even when parsing fails
        try (InputStream in = Resources.getResourceAsStream(xmlPath)) {
            SAXReader saxReader = new SAXReader();
            Document document = saxReader.read(in);
            Element root = document.getRootElement();
            String namespace = root.attributeValue("namespace");

            List selectNodes = root.selectNodes("//select");
            mappers.putAll(NodeToMapper(namespace, selectNodes, SqlCommandType.SELECT));

            List insertNodes = root.selectNodes("//insert");
            mappers.putAll(NodeToMapper(namespace, insertNodes, SqlCommandType.INSERT));

            List updateNodes = root.selectNodes("//update");
            mappers.putAll(NodeToMapper(namespace, updateNodes, SqlCommandType.UPDATE));

            List deleteNodes = root.selectNodes("//delete");
            mappers.putAll(NodeToMapper(namespace, deleteNodes, SqlCommandType.DELETE));
        } catch (Exception e) {
            e.printStackTrace();
        }

        return mappers;
    }

    private static Map<String, Mapper> NodeToMapper(String namespace, List list, SqlCommandType type){
        Map<String, Mapper> mappers = new HashMap<String, Mapper>();
        Iterator iterator = list.iterator();
        while (iterator.hasNext()) {
            Element element = (Element) iterator.next();
            String id = element.attributeValue("id");
            String resultType = element.attributeValue("resultType");
            String parameterType = element.attributeValue("parameterType");
            String resultMap = element.attributeValue("resultMap");
            String key = namespace + "." + id;
            String sql = element.getText();
            Mapper mapper = new Mapper(key,sql, element, resultType, parameterType, resultMap,type);
            mappers.put(key, mapper);
        }
        return mappers;
    }
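The annotation path boils down to reflection over the DAO interface. The self-contained sketch below uses a hypothetical `@Select` annotation declared inline (standing in for the project's own `Select`) and shows the two points that matter: the annotation needs `RUNTIME` retention to be visible to reflection, and the result type comes from the generic return type for `List<User>` and from the plain return type otherwise. `AnnotationScanDemo` and its nested types are illustrative names.

```java
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.lang.reflect.Method;
import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;

public class AnnotationScanDemo {

    // hypothetical stand-in for the project's @Select; RUNTIME retention is required for reflection
    @Retention(RetentionPolicy.RUNTIME)
    @Target(ElementType.METHOD)
    public @interface Select {
        String value();
    }

    public static class User {}

    public interface UserDao {
        @Select("select * from user")
        List<User> findAll();

        @Select("select * from user where id = #{id}")
        User findById(Integer id);
    }

    /** Maps "interface.method" to "sql | resultType", one fresh entry per annotated method. */
    public static Map<String, String> scan(Class<?> daoClass) {
        Map<String, String> result = new TreeMap<>();
        for (Method method : daoClass.getMethods()) {
            Select select = method.getAnnotation(Select.class);
            if (select == null) {
                continue;
            }
            Type type = method.getGenericReturnType();
            Class<?> resultClass;
            if (type instanceof ParameterizedType) {
                // List<User> -> User
                resultClass = (Class<?>) ((ParameterizedType) type).getActualTypeArguments()[0];
            } else {
                resultClass = method.getReturnType();
            }
            result.put(daoClass.getName() + "." + method.getName(),
                    select.value() + " | " + resultClass.getSimpleName());
        }
        return result;
    }

    public static void main(String[] args) {
        scan(UserDao.class).forEach((k, v) -> System.out.println(k + " -> " + v));
    }
}
```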

An enum marking the SQL command type:

public enum SqlCommandType {
    UNKNOWN, INSERT, UPDATE, DELETE, SELECT, FLUSH
}
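Each parsed node is tagged with one of these constants and stored under the key `namespace + "." + id`. As an alternative to calling `selectNodes` once per tag, the tag name can be mapped straight onto the enum. A small sketch with a hypothetical `StatementRegistry` class, assuming tag names match the constant names case-insensitively:

```java
import java.util.HashMap;
import java.util.Locale;
import java.util.Map;

public class StatementRegistry {

    public enum SqlCommandType { UNKNOWN, INSERT, UPDATE, DELETE, SELECT, FLUSH }

    /** "select" -> SELECT, "delete" -> DELETE, ...; anything unrecognised becomes UNKNOWN. */
    public static SqlCommandType fromTag(String tagName) {
        try {
            return SqlCommandType.valueOf(tagName.toUpperCase(Locale.ROOT));
        } catch (IllegalArgumentException e) {
            return SqlCommandType.UNKNOWN;
        }
    }

    private final Map<String, SqlCommandType> types = new HashMap<>();

    /** Same key format as NodeToMapper: namespace + "." + id. */
    public void register(String namespace, String id, String tagName) {
        types.put(namespace + "." + id, fromTag(tagName));
    }

    public SqlCommandType lookup(String statementId) {
        return types.getOrDefault(statementId, SqlCommandType.UNKNOWN);
    }

    public static void main(String[] args) {
        StatementRegistry registry = new StatementRegistry();
        registry.register("com.baiye.www.dao.IUserDao", "findById", "select");
        registry.register("com.baiye.www.dao.IUserDao", "deletebyId", "delete");
        System.out.println(registry.lookup("com.baiye.www.dao.IUserDao.findById"));   // SELECT
        System.out.println(registry.lookup("com.baiye.www.dao.IUserDao.deletebyId")); // DELETE
    }
}
```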

After reading, the information is wrapped in a Configuration class and Mapper objects.

public class Configuration {
    protected Environment environment;
    private String dataSourceType;
    private String driver;
    private String url;
    private String username;
    private String password;
    private Cache perpetualCache = new PerpetualCache();
    private boolean enableCache;
    //fully qualified interface name + method name -> Mapper
    private Map<String, Mapper> mappers = new HashMap<String, Mapper>();
    ......
    }
public class Mapper {

    private String sql;
    private Element element;
    private String resultType;
    private String parameterType;
    private String resultMap;
    ......
    }

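III. Creating the Factory Classes

First we create the SqlSessionFactoryBuilder class, which builds the factory from the configuration it reads. It offers several build overloads (the environment and properties arguments are accepted but not used yet).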

public class SqlSessionFactoryBuilder {
    private SqlSessionFactory build(InputStream inputStream, String environment, Properties properties) {
        Configuration configuration = XMLConfigBuilder.loadConfiguration(inputStream);
        return build(configuration);
    }

    public SqlSessionFactory build(InputStream inputStream) {
        return build(inputStream, null, null);
    }

    public SqlSessionFactory build(Configuration configuration) {

        return new DefaultSqlSessionFactory(configuration);
    }
}

Following the design of the MyBatis source, we create the SqlSessionFactory and SqlSession interfaces, together with their default implementations DefaultSqlSessionFactory and DefaultSqlSession:

public interface SqlSessionFactory {
    SqlSession openSession();

    SqlSession openSession(boolean autoCommit);
}
public interface SqlSession {

    <T> T getMapper(Class<T> daoInterfaceClass);

    <T> T selectOne(String mapperName, Object[] parameter);

    <T> List<T> selectList(String mapperName, Object[] parameter);

    int insert(String mapperName, Object[] parameter);

    int update(String mapperName, Object[] parameter);

    int delete(String mapperName, Object[] parameter);

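    void clearCache();

    /**
     * Closing without committing rolls the transaction back
     */
    void close();

    void rollback();

    void commit();

    //needed by MapperProxy to look up the SqlCommandType of a statement
    Configuration getConfiguration();

}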

DefaultSqlSession holds the configuration along with the transaction-related state:

public class DefaultSqlSession implements SqlSession {
    private final boolean autoCommit;
    protected Configuration configuration;
    private Executor executor;
    private boolean dirty;
    .....
    }
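The `autoCommit` and `dirty` fields are what make the "closing without committing rolls back" rule work. A minimal sketch of that bookkeeping, modeled on the check MyBatis's own `DefaultSqlSession` performs (`!autoCommit && dirty`); `DirtyFlagSessionDemo`, `Session` and `FakeTransaction` are hypothetical names, and `FakeTransaction` only records calls in place of the real Executor and JDBC connection.

```java
import java.util.ArrayList;
import java.util.List;

public class DirtyFlagSessionDemo {

    /** Hypothetical stand-in for the transaction behind the Executor; it only records calls. */
    public static class FakeTransaction {
        public final List<String> log = new ArrayList<>();
        void commit() { log.add("commit"); }
        void rollback() { log.add("rollback"); }
        void close() { log.add("close"); }
    }

    public static class Session {
        private final boolean autoCommit;
        private final FakeTransaction tx;
        private boolean dirty; // true while an insert/update/delete is not yet committed

        public Session(boolean autoCommit, FakeTransaction tx) {
            this.autoCommit = autoCommit;
            this.tx = tx;
        }

        public int update(String statementId) {
            dirty = true;
            return 1; // pretend one row changed
        }

        // commit/rollback is only needed when writes are pending and the connection is not auto-committing
        private boolean commitOrRollbackRequired() {
            return !autoCommit && dirty;
        }

        public void commit() {
            if (commitOrRollbackRequired()) {
                tx.commit();
            }
            dirty = false;
        }

        public void close() {
            // closing with uncommitted writes rolls them back
            if (commitOrRollbackRequired()) {
                tx.rollback();
            }
            dirty = false;
            tx.close();
        }
    }

    public static void main(String[] args) {
        FakeTransaction tx = new FakeTransaction();
        Session session = new Session(false, tx);
        session.update("IUserDao.deletebyId");
        session.close();
        System.out.println(tx.log); // [rollback, close]
    }
}
```

With `openSession(true)`, as in the test code at the top, `autoCommit` is true and `close()` never rolls back.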

IV. The Proxy Object

This method is the key:

@Override
    public <T> T getMapper(Class<T> daoInterfaceClass) {
        return (T) Proxy.newProxyInstance(daoInterfaceClass.getClassLoader(), new Class[]{daoInterfaceClass}, new MapperProxy(this));
    }

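It returns a JDK dynamic proxy that implements the DAO interface; every call on the proxy is routed to MapperProxy.invoke.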
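`Proxy.newProxyInstance` generates, at runtime, a class implementing the given interfaces and sends every call to the `InvocationHandler`. A self-contained sketch of that mechanism; `ProxyDemo`, its `IUserDao` and `RecordingHandler` are illustrative names, and the handler only records the statement id that `MapperProxy` would compute instead of running SQL.

```java
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;
import java.util.ArrayList;
import java.util.List;

public class ProxyDemo {

    public interface IUserDao {
        String findById(Integer id);
    }

    /** Records the statement id of every call instead of executing SQL. */
    public static class RecordingHandler implements InvocationHandler {
        public final List<String> calls = new ArrayList<>();

        @Override
        public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
            if (method.getDeclaringClass() == Object.class) {
                // toString/hashCode/equals also arrive here; answer them on behalf of the handler
                return method.invoke(this, args);
            }
            String statementId = method.getDeclaringClass().getName() + "." + method.getName();
            calls.add(statementId);
            return "result of " + statementId + " with id=" + args[0];
        }
    }

    @SuppressWarnings("unchecked")
    public static <T> T getMapper(Class<T> daoInterface, InvocationHandler handler) {
        return (T) Proxy.newProxyInstance(daoInterface.getClassLoader(), new Class<?>[]{daoInterface}, handler);
    }

    public static void main(String[] args) {
        RecordingHandler handler = new RecordingHandler();
        IUserDao dao = getMapper(IUserDao.class, handler);
        System.out.println(dao.findById(41));
        System.out.println(handler.calls);
    }
}
```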

Here is the MapperProxy class:

package com.baiye.www.mybaits.proxy;

import com.baiye.www.mybaits.annotation.Param;
import com.baiye.www.mybaits.sqlsession.SqlSession;

import java.lang.annotation.Annotation;
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Method;
import java.util.Collection;
import java.util.HashMap;
import java.util.Map;

/**
 * Created with IntelliJ IDEA.
 *
 * @Author: baiye
 * @Date: 2021/07/13/16:54
 * @Description:
 */
public class MapperProxy implements InvocationHandler {
    //map key: fully qualified interface name + method name
    //private Map<String, Mapper> mappers;
    private SqlSession sqlSession;


    public MapperProxy(SqlSession sqlSession) {
        this.sqlSession = sqlSession;
    }

    @Override
    public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
        System.out.println(method.getDeclaringClass().getName() + "." + method.getName());
        //Execution is ultimately handed to sqlSession, because sqlSession wraps the Executor.
        //The statement id (class name + method name) and the arguments go to the matching sqlSession method.

        if (args != null && args.length > 1) {
            //several arguments: replace each one with a {@Param name -> value} map
            Annotation[][] annotation = method.getParameterAnnotations();
            for (int i = 0; i < args.length; i++) {
                Param anno = null;
                for (Annotation a : annotation[i]) {
                    if (a instanceof Param) {
                        anno = (Param) a;
                    }
                }
                if (anno == null) {
                    throw new RuntimeException("parameter " + i + " of " + method.getName() + " is missing @Param");
                }
                Map<String, String> map = new HashMap<>(1);
                map.put(anno.value(), args[i] + "");
                args[i] = map;
            }
        }
        String statementId = method.getDeclaringClass().getName() + "." + method.getName();
        SqlCommandType sqlType = sqlSession.getConfiguration().getMappers().get(statementId).getSqlType();
        if (sqlType == SqlCommandType.SELECT) {
            //a Collection return type (e.g. List<User>) means several rows
            if (Collection.class.isAssignableFrom(method.getReturnType())) {
                return sqlSession.selectList(statementId, args);
            }
            return sqlSession.selectOne(statementId, args);
        } else if (sqlType == SqlCommandType.UPDATE || sqlType == SqlCommandType.DELETE || sqlType == SqlCommandType.INSERT) {
            //insert, update and delete are all dispatched to sqlSession.update()
            return sqlSession.update(statementId, args);
        } else {
            throw new RuntimeException("no such SqlType");
        }

    }
}
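The `@Param` handling in `MapperProxy` wraps each argument in its own one-entry map and converts the value to a String. A sketch of an alternative, assuming a hypothetical inline `@Param` annotation (standing in for the project's `com.baiye.www.mybaits.annotation.Param`): collect all arguments into one name → value map and keep the original value types, so later code can still tell an `Integer` from a `String`. `ParamBindingDemo` and `bind` are illustrative names.

```java
import java.lang.annotation.Annotation;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.lang.reflect.Method;
import java.util.LinkedHashMap;
import java.util.Map;

public class ParamBindingDemo {

    // hypothetical stand-in for the project's @Param annotation
    @Retention(RetentionPolicy.RUNTIME)
    @Target(ElementType.PARAMETER)
    public @interface Param {
        String value();
    }

    public interface IUserDao {
        Object findSome(@Param("username") String username, @Param("sex") String sex);
    }

    /** Collects all arguments into one name -> value map, keeping the original value types. */
    public static Map<String, Object> bind(Method method, Object[] args) {
        Map<String, Object> params = new LinkedHashMap<>();
        Annotation[][] annotations = method.getParameterAnnotations();
        for (int i = 0; i < args.length; i++) {
            String name = null;
            for (Annotation a : annotations[i]) {
                if (a instanceof Param) {
                    name = ((Param) a).value();
                }
            }
            if (name == null) {
                throw new IllegalArgumentException("parameter " + i + " of " + method.getName() + " has no @Param");
            }
            params.put(name, args[i]);
        }
        return params;
    }

    public static void main(String[] args) throws NoSuchMethodException {
        Method m = IUserDao.class.getMethod("findSome", String.class, String.class);
        System.out.println(bind(m, new Object[]{"%王%", "男"})); // {username=%王%, sex=男}
    }
}
```

A single map means the SQL builder looks up every `#{name}` in one place instead of iterating over one map per argument.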

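V. Executing the SQL Statement

The proxy object calls SqlSession methods, and the session ultimately hands the SQL to the Executor for execution.

Below is the code I wrote to build the SQL statement. It does not follow MyBatis and is extremely crude and simple: parameter values are spliced directly into the SQL text, which leaves it open to SQL injection (MyBatis itself binds #{} parameters through PreparedStatement placeholders).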

public static String paramToSql(String originalSql, Object[] object) throws InvocationTargetException, IllegalAccessException, IntrospectionException {
        int paramCount = StringUtil.getTargetStringNum(originalSql, "#");
        //no parameters
        if (paramCount < 1) {
            return originalSql;
        } else if (paramCount == 1) { //exactly one parameter, i.e. a single basic value: convert it to a string
            String regex = "#\\{[^}]*}";
            //replace the #{...} placeholder with the quoted value;
            //quoteReplacement stops '$' or '\' in the value from being read as a group reference
            return originalSql.replaceAll(regex, Matcher.quoteReplacement("\"" + object[0] + "\""));
        } else if (paramCount > 1 && object.length == 1) { //several placeholders and a single entity object
            Class<?> obj = object[0].getClass();
            Field[] declaredFields = obj.getDeclaredFields();
            for (Field field : declaredFields) {
                //static fields such as serialVersionUID have no getter
                if (Modifier.isStatic(field.getModifiers())) {
                    continue;
                }
                String fieldName = field.getName();
                PropertyDescriptor propertyDescriptor = new PropertyDescriptor(fieldName, obj);
                Method readMethod = propertyDescriptor.getReadMethod();
                Object o = readMethod.invoke(object[0]);
                if (o != null) {
                    originalSql = originalSql.replace("#{" + fieldName + "}", "\"" + o + "\"");
                } else {
                    originalSql = originalSql.replace("#{" + fieldName + "}", "null");
                }
            }
            return originalSql;
        } else { //several parameters, each a basic value wrapped by MapperProxy as {@Param name -> value}
            for (Object o : object) {
                //MapperProxy stores String values; declared as <String, Object> to tolerate other value types
                @SuppressWarnings("unchecked")
                Map<String, Object> map = (Map<String, Object>) o;
                for (Map.Entry<String, Object> entry : map.entrySet()) {
                    String regex = "#\\{" + entry.getKey() + "}";
                    Object value = entry.getValue();
                    //replace #{name} in the SQL with the actual value; check null first to avoid an NPE
                    if (value == null) {
                        originalSql = originalSql.replaceAll(regex, "null");
                    } else if (value instanceof Integer || "1 OR 1".equals(value)) {
                        //unquoted: numbers, plus the "1 OR 1" value kept from my tests
                        originalSql = originalSql.replaceAll(regex, Matcher.quoteReplacement(value.toString()));
                    } else {
                        originalSql = originalSql.replaceAll(regex, Matcher.quoteReplacement("\"" + value + "\""));
                    }
                }
            }
            return originalSql;
        }
    }
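Because `paramToSql` splices values into the SQL text, a value such as `1 OR 1` changes the meaning of the statement. MyBatis instead turns each `#{...}` into a JDBC `?` placeholder and binds the values through a `PreparedStatement`. A minimal sketch of the parsing half of that approach; the class name `PlaceholderParser` is hypothetical.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

public class PlaceholderParser {

    private static final Pattern PLACEHOLDER = Pattern.compile("#\\{([^}]*)}");

    /** Result of parsing: SQL with '?' markers plus the parameter names in order. */
    public static final class ParsedSql {
        public final String sql;
        public final List<String> paramNames;

        ParsedSql(String sql, List<String> paramNames) {
            this.sql = sql;
            this.paramNames = paramNames;
        }
    }

    public static ParsedSql parse(String originalSql) {
        Matcher m = PLACEHOLDER.matcher(originalSql);
        StringBuffer sb = new StringBuffer();
        List<String> names = new ArrayList<>();
        while (m.find()) {
            names.add(m.group(1).trim()); // remember which value belongs to this '?'
            m.appendReplacement(sb, "?");
        }
        m.appendTail(sb);
        return new ParsedSql(sb.toString(), names);
    }

    public static void main(String[] args) {
        ParsedSql p = parse("select * from `user` where username LIKE #{username} and sex = #{sex}");
        System.out.println(p.sql);        // select * from `user` where username LIKE ? and sex = ?
        System.out.println(p.paramNames); // [username, sex]
    }
}
```

The name list then drives the binding: for a `PreparedStatement ps` created from `p.sql`, calling `ps.setObject(i + 1, params.get(p.paramNames.get(i)))` for each index (where `params` is a hypothetical name → value map) lets the driver handle quoting, so a value like `1 OR 1` stays data.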

VI. Mapping the Results

For queries, the Executor maps the rows of the ResultSet to objects with the following code:

            //the metadata is the same for every row, so read it once
            ResultSetMetaData metaData = resultSet.getMetaData();
            while (resultSet.next()) {
                //pojoClass is the class named by resultType
                E pojo = (E) pojoClass.newInstance();
                for (int i = 1; i <= metaData.getColumnCount(); i++) {
                    String columnName = metaData.getColumnName(i);
                    Object value = resultSet.getObject(i);
                    //column name such as user_name -> property name userName
                    columnName = StringUtil.underlineToHump(columnName);

                    //call the matching setter on the POJO
                    PropertyDescriptor propertyDescriptor = new PropertyDescriptor(columnName, pojoClass);
                    Method writeMethod = propertyDescriptor.getWriteMethod();
                    writeMethod.invoke(pojo, value);
                }
                list.add(pojo);
            }
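`StringUtil.underlineToHump` is not shown in the post. Below is a hypothetical version, plus the same setter-based mapping applied to a `Map` that stands in for one `ResultSet` row, so the sketch runs without a database. `RowMappingDemo`, `mapRow` and the two-field `User` are illustrative only.

```java
import java.beans.PropertyDescriptor;
import java.lang.reflect.Method;
import java.util.LinkedHashMap;
import java.util.Locale;
import java.util.Map;

public class RowMappingDemo {

    /** Hypothetical version of StringUtil.underlineToHump: "user_name" -> "userName". */
    public static String underlineToHump(String column) {
        StringBuilder sb = new StringBuilder();
        boolean upperNext = false;
        for (char c : column.toLowerCase(Locale.ROOT).toCharArray()) {
            if (c == '_') {
                upperNext = true;
            } else {
                sb.append(upperNext ? Character.toUpperCase(c) : c);
                upperNext = false;
            }
        }
        return sb.toString();
    }

    public static class User {
        private Integer id;
        private String userName;
        public Integer getId() { return id; }
        public void setId(Integer id) { this.id = id; }
        public String getUserName() { return userName; }
        public void setUserName(String userName) { this.userName = userName; }
        @Override public String toString() { return "User{id=" + id + ", userName=" + userName + "}"; }
    }

    /** Copies one row (column name -> value) into a new bean through its setters. */
    public static <E> E mapRow(Map<String, Object> row, Class<E> pojoClass) throws Exception {
        E pojo = pojoClass.getDeclaredConstructor().newInstance();
        for (Map.Entry<String, Object> column : row.entrySet()) {
            PropertyDescriptor pd = new PropertyDescriptor(underlineToHump(column.getKey()), pojoClass);
            Method writeMethod = pd.getWriteMethod();
            writeMethod.invoke(pojo, column.getValue());
        }
        return pojo;
    }

    public static void main(String[] args) throws Exception {
        Map<String, Object> row = new LinkedHashMap<>();
        row.put("id", 41);
        row.put("user_name", "wang");
        System.out.println(mapRow(row, User.class)); // User{id=41, userName=wang}
    }
}
```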

Project repository

GitHub: https://github.com/Alice-175/Mybaits

Gitee: https://gitee.com/alice-175/Mybaits
