mybatis sql自动生成

2 篇文章 0 订阅
2 篇文章 0 订阅

本文介绍一种借助 MyBatis 拦截器自动生成通用 CRUD SQL 语句的方法。

1. 编写一个拦截器

package com.jeff.mybatis.autobuild;

import java.sql.Connection;
import java.util.Properties;

import org.apache.ibatis.executor.statement.StatementHandler;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.plugin.Interceptor;
import org.apache.ibatis.plugin.Intercepts;
import org.apache.ibatis.plugin.Invocation;
import org.apache.ibatis.plugin.Plugin;
import org.apache.ibatis.plugin.Signature;
import org.apache.ibatis.reflection.MetaObject;
import org.apache.ibatis.reflection.factory.DefaultObjectFactory;
import org.apache.ibatis.reflection.factory.ObjectFactory;
import org.apache.ibatis.reflection.wrapper.DefaultObjectWrapperFactory;
import org.apache.ibatis.reflection.wrapper.ObjectWrapperFactory;

/**
 * MyBatis plugin that generates the SQL of generic CRUD statements at run time.
 *
 * <p>It intercepts {@code StatementHandler.prepare(Connection)}. It takes the method name of the
 * mapped statement being prepared (the part of its id after the last '.'). When that name matches
 * one of the patterns below, the statement's SQL is replaced with text built by
 * {@code SqlBulider}. Statements whose method name matches no pattern keep their original SQL.
 *
 * <p>{@code @Intercepts} marks this object as an Interceptor; {@code @Signature} names the
 * intercepted interface, method and argument types. The intercepted type in a signature must be
 * an interface.
 *
 * @author jeff he
 */
@Intercepts({ @Signature(type = StatementHandler.class, method = "prepare", args = { Connection.class }) })
public class AutoBuildInterceptor implements Interceptor {
    private static final ObjectFactory DEFAULT_OBJECT_FACTORY = new DefaultObjectFactory();
    private static final ObjectWrapperFactory DEFAULT_OBJECT_WRAPPER_FACTORY = new DefaultObjectWrapperFactory();

    /*
     * Regexes matched (case-sensitively) against the METHOD NAME of a statement, not its full id.
     * Matching the full id would let a namespace such as "...PlaylistMapper" trigger the "list"
     * rewrite for every statement in it.
     *
     * Every pattern except FINDBYPAGE_MATCHER matches any name that merely CONTAINS the keyword.
     * For example, a custom "updatePassword" statement is also replaced by the generic update SQL.
     */
    private static final String UPDATE_MATCHER = "^.*update.*";
    private static final String INSERT_MATCHER = "^.*insert.*";
    private static final String BATCH_UPDATE_MATCHER = "^.*batchUpdate.*";
    private static final String BATCH_INSERT_MATCHER = "^.*batchInsert.*";
    private static final String DELETEBYID_MATCHER = "^.*deleteById.*";
    private static final String GETBYID_MATCHER = "^.*getById.*";
    private static final String LIST_MATCHER = "^.*list.*";
    private static final String FINDBYPAGE_MATCHER = "^.*findByPage";

    /** Manual smoke check of BATCH_UPDATE_MATCHER; prints {@code true}. */
    public static void main(String[] args) {
        System.out.println("batchUpdatexx".matches(BATCH_UPDATE_MATCHER));
    }

    /**
     * Rewrites the SQL of the statement about to be prepared (when its method name matches a
     * pattern), then lets the invocation continue down the interceptor chain.
     */
    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        StatementHandler statementHandler = (StatementHandler) invocation
                .getTarget();
        MetaObject metaStatementHandler = MetaObject.forObject(
                statementHandler, DEFAULT_OBJECT_FACTORY,
                DEFAULT_OBJECT_WRAPPER_FACTORY);
        // The handler may be proxied by several plugins. First follow the proxies' invocation
        // handlers (a java.lang.reflect.Proxy keeps its handler in the field "h")...
        while (metaStatementHandler.hasGetter("h")) {
            Object object = metaStatementHandler.getValue("h");
            metaStatementHandler = MetaObject.forObject(object,
                    DEFAULT_OBJECT_FACTORY, DEFAULT_OBJECT_WRAPPER_FACTORY);
        }
        // ...then the "target" fields, down to the innermost, un-proxied handler.
        while (metaStatementHandler.hasGetter("target")) {
            Object object = metaStatementHandler.getValue("target");
            metaStatementHandler = MetaObject.forObject(object,
                    DEFAULT_OBJECT_FACTORY, DEFAULT_OBJECT_WRAPPER_FACTORY);
        }
        MappedStatement mappedStatement = (MappedStatement) metaStatementHandler
                .getValue("delegate.mappedStatement");
        BoundSql boundSql = (BoundSql) metaStatementHandler
                .getValue("delegate.boundSql");
        String sql = rewriteSql(methodName(mappedStatement.getId()),
                boundSql.getParameterObject(), boundSql.getSql());
        metaStatementHandler.setValue("delegate.boundSql.sql", sql);
        // Hand control to the next interceptor (and eventually the real prepare()).
        return invocation.proceed();
    }

    /** Returns the part of a statement id after its last '.', or the whole id if it has none. */
    private static String methodName(String sqlId) {
        return sqlId.substring(sqlId.lastIndexOf('.') + 1);
    }

    /**
     * Picks the SQL generator for {@code methodName}.
     *
     * @param methodName      method-name part of the mapped statement id
     * @param parameterObject the statement's parameter (an entity, or an id for by-id statements)
     * @param sql             the SQL MyBatis built from the mapper XML
     * @return the generated SQL, or {@code sql} unchanged when no pattern matches
     */
    private static String rewriteSql(String methodName, Object parameterObject, String sql)
            throws IllegalAccessException, NoSuchFieldException {
        // The batch patterns are more specific than the plain insert/update ones, so test them first.
        if (methodName.matches(BATCH_INSERT_MATCHER)) {
            return SqlBulider.buildBatchInsertSql(sql);
        }
        if (methodName.matches(BATCH_UPDATE_MATCHER)) {
            return SqlBulider.buildBatchUpdateSql(sql);
        }
        if (methodName.matches(INSERT_MATCHER)) {
            return SqlBulider.buildLogInsertSql(SqlBulider.buildRawInsertSql(parameterObject));
        }
        if (methodName.matches(UPDATE_MATCHER)) {
            return SqlBulider.buildLogUpdateSql(SqlBulider.buildRawUpdateSql(parameterObject));
        }
        if (methodName.matches(GETBYID_MATCHER)) {
            return SqlBulider.buildGetByIdSql(parameterObject);
        }
        if (methodName.matches(DELETEBYID_MATCHER)) {
            return SqlBulider.buildDeleteByIdSql(parameterObject);
        }
        if (methodName.matches(LIST_MATCHER) || methodName.matches(FINDBYPAGE_MATCHER)) {
            return SqlBulider.buildQuerySql(parameterObject);
        }
        return sql;
    }

    /**
     * Wraps only {@code StatementHandler} targets. Every other object is returned as-is, which
     * keeps the number of proxy layers down.
     */
    @Override
    public Object plugin(Object target) {
        if (target instanceof StatementHandler) {
            return Plugin.wrap(target, this);
        }
        return target;
    }

    /** This interceptor has no configurable properties. */
    @Override
    public void setProperties(Properties properties) {
    }

}

2. 用于重写 SQL 的 SQL 构造类

package com.jeff.mybatis.autobuild;

import java.lang.reflect.Field;
import java.text.SimpleDateFormat;
import java.util.UUID;

import org.apache.ibatis.session.SqlSessionException;

/**
 * Builds the SQL text for the generic CRUD statements from entity objects via reflection.
 *
 * <p>The generated SQL is MySQL-flavoured: backtick identifiers, {@code now()}. Table and column
 * names are the snake_case forms (see {@code NameConverter}) of the class simple name and field
 * names.
 *
 * <p>Values are inlined as escaped literals rather than bound parameters. The escaping doubles
 * single quotes and backslashes. NOTE(review): this assumes MySQL's default sql_mode (backslash
 * escapes enabled, NO_BACKSLASH_ESCAPES off) — confirm for the target database.
 */
public class SqlBulider {

    /** Format of every java.util.Date value written into generated SQL. */
    private static final String DATE_PATTERN = "yyyy-MM-dd HH:mm:ss";

    /**
     * Builds an INSERT statement for {@code obj}, without the audit columns.
     *
     * <p>Columns are the public fields inherited from superclasses, followed by the fields the
     * class declares itself. Static and synthetic fields are ignored. Inherited fields whose
     * value is {@code null} are left out of the statement; null own fields are written as
     * {@code null}. A String {@code id} that is {@code null} or empty (easyui forms submit
     * {@code ""}) is replaced with a random UUID.
     *
     * @param obj the entity to insert; must not be {@code null}
     * @return e.g. {@code insert into `user_info` (id,user_name) values ('...','bob')}
     * @throws IllegalAccessException if a field cannot be read reflectively
     */
    public static String buildRawInsertSql(Object obj)
            throws IllegalArgumentException, IllegalAccessException {
        Class<?> clz = obj.getClass();
        StringBuilder columns = new StringBuilder();
        StringBuilder values = new StringBuilder();
        for (Field field : mappedFields(clz, true)) {
            Object value = insertValue(field, obj);
            // Skip the column together with its value so both lists stay the same length.
            if (value != null) {
                appendInsertPair(columns, values, field, value);
            }
        }
        for (Field field : mappedFields(clz, false)) {
            appendInsertPair(columns, values, field, insertValue(field, obj));
        }
        return "insert into `" + NameConverter.conver(clz.getSimpleName()) + "` ("
                + columns + ") values (" + values + ")";
    }

    /**
     * Builds an UPDATE statement that sets every non-null field of {@code obj} except {@code id},
     * and uses {@code id} in the where clause.
     *
     * @param obj the entity to update; must not be {@code null}
     * @return e.g. {@code update  `user_info` set user_name = 'bob', age = 3 where id = '42'}
     * @throws NoSuchFieldException   if the class has no mapped field named {@code id}
     * @throws IllegalAccessException if a field cannot be read reflectively
     * @throws SqlSessionException    if {@code id} is null or no other field is non-null
     */
    public static String buildRawUpdateSql(Object obj)
            throws IllegalArgumentException, IllegalAccessException,
            NoSuchFieldException, SecurityException {
        Class<?> clz = obj.getClass();
        StringBuilder sql = new StringBuilder("update  `"
                + NameConverter.conver(clz.getSimpleName()) + "` set");
        Field idField = null;
        boolean hasAssignment = false;
        for (Field field : allMappedFields(clz)) {
            if ("id".equals(field.getName())) {
                idField = field; // the primary key belongs in the where clause, never in set
                continue;
            }
            Object value = field.get(obj);
            if (value == null) {
                continue; // null fields are not assigned, so those columns keep their values
            }
            sql.append(hasAssignment ? ", " : " ")
                    .append(NameConverter.conver(field.getName()))
                    .append(" = ")
                    .append(toSqlLiteral(value));
            hasAssignment = true;
        }
        if (idField == null) {
            throw new NoSuchFieldException("id");
        }
        if (!hasAssignment) {
            throw new SqlSessionException("No non-null field to update in " + clz.getName());
        }
        return sql.append(" where id = ").append(idLiteral(idField.get(obj))).toString();
    }

    /**
     * Builds {@code delete from `table` where id = ...}. The table comes from
     * {@code MybatisContext.getClzName()}.
     *
     * @throws SqlSessionException if {@code id} is null or no class name is set in the context
     */
    public static String buildDeleteByIdSql(Object id) {
        return "delete from `" + contextTableName() + "` where id = " + idLiteral(id);
    }

    /**
     * Builds {@code select * from `table` where id = ...}. The table comes from
     * {@code MybatisContext.getClzName()}.
     *
     * @throws SqlSessionException if {@code id} is null or no class name is set in the context
     */
    public static String buildGetByIdSql(Object id) {
        return "select * from `" + contextTableName() + "` where id = " + idLiteral(id);
    }

    /**
     * Builds a SELECT that filters on equality with every non-null field of {@code obj}.
     *
     * @param obj the query-by-example entity; must not be {@code null}
     * @return e.g. {@code select * from `user_info`  where id is not null and age = 3}
     * @throws IllegalAccessException if a field cannot be read reflectively
     */
    public static String buildQuerySql(Object obj)
            throws IllegalArgumentException, IllegalAccessException {
        Class<?> clz = obj.getClass();
        StringBuilder sql = new StringBuilder("select * from `"
                + NameConverter.conver(clz.getSimpleName()) + "`  where id is not null");
        for (Field field : allMappedFields(clz)) {
            Object value = field.get(obj);
            if (value != null) {
                sql.append(" and ").append(NameConverter.conver(field.getName()))
                        .append(" = ").append(toSqlLiteral(value));
            }
        }
        return sql.toString();
    }

    /**
     * Adds the {@code create_dt}/{@code create_id} audit columns to a single-row insert.
     *
     * <p>The input must have the form {@code insert ... (columns) values (...)}. {@code create_dt}
     * is set to {@code now()}. {@code create_id} is set to the current context user, or SQL
     * {@code null} when none is set.
     *
     * @throws IllegalArgumentException if {@code sql} has no column list and value list
     */
    public static String buildLogInsertSql(String sql) {
        int columnsEnd = sql.indexOf(')');
        int valuesEnd = sql.lastIndexOf(')');
        if (columnsEnd < 0 || valuesEnd <= columnsEnd) {
            throw new IllegalArgumentException(
                    "Expected an insert statement of the form 'insert ... (columns) values (...)'");
        }
        return sql.substring(0, columnsEnd) + ",create_dt,create_id"
                + sql.substring(columnsEnd, valuesEnd) + ",now()," + currentUserLiteral()
                + sql.substring(valuesEnd);
    }

    /**
     * Adds {@code update_dt=now(),update_id=<current user>} to an update built by
     * {@link #buildRawUpdateSql(Object)}, right before its trailing where clause.
     *
     * <p>The LAST " where " is used, so a SET value containing the word cannot move the
     * insertion point.
     *
     * @throws IllegalArgumentException if {@code sql} contains no " where "
     */
    public static String buildLogUpdateSql(String sql) {
        int where = sql.lastIndexOf(" where ");
        if (where < 0) {
            throw new IllegalArgumentException("Expected an update statement with a where clause");
        }
        return sql.substring(0, where + 1) + updateAuditAssignments() + sql.substring(where + 1);
    }

    /**
     * Adds the update audit assignments before every {@code where} keyword of a (possibly
     * multi-statement) batch update.
     *
     * <p>The keyword is matched case-insensitively, as a whole word, and outside quoted text.
     * The rest of the SQL is left untouched (no case folding).
     */
    public static String buildBatchUpdateSql(String sql) {
        String audit = updateAuditAssignments();
        StringBuilder out = new StringBuilder(sql.length() + 64);
        int i = 0;
        while (i < sql.length()) {
            char c = sql.charAt(i);
            if (c == '\'' || c == '"' || c == '`') {
                i = copyQuoted(sql, i, out);
                continue;
            }
            if (isKeywordAt(sql, i, "where")) {
                out.append(audit);
            }
            out.append(c);
            i++;
        }
        return out.toString();
    }

    /**
     * Adds the audit columns to a multi-row insert.
     *
     * <p>The first top-level parenthesised group (the column list) gets
     * {@code ,create_dt,create_id}. Every later top-level group (one per row) gets
     * {@code ,now(),<current user>}. Parentheses of nested calls such as {@code uuid()}, and
     * parentheses inside quoted text, are ignored. The SQL's case is preserved.
     */
    public static String buildBatchInsertSql(String sql) {
        String rowAudit = ",now()," + currentUserLiteral();
        StringBuilder out = new StringBuilder(sql.length() + 64);
        boolean columnListClosed = false;
        int depth = 0;
        int i = 0;
        while (i < sql.length()) {
            char c = sql.charAt(i);
            if (c == '\'' || c == '"' || c == '`') {
                i = copyQuoted(sql, i, out);
                continue;
            }
            if (c == '(') {
                depth++;
            } else if (c == ')' && depth > 0) {
                depth--;
                if (depth == 0) {
                    out.append(columnListClosed ? rowAudit : ",create_dt,create_id");
                    columnListClosed = true;
                }
            }
            out.append(c);
            i++;
        }
        return out.toString();
    }

    /** Reads a field for an insert, replacing a null/empty String {@code id} with a random UUID. */
    private static Object insertValue(Field field, Object obj) throws IllegalAccessException {
        Object value = field.get(obj);
        // easyui form submissions send "" for an empty id.
        if ("id".equals(field.getName())
                && ("".equals(value) || (value == null && field.getType() == String.class))) {
            value = UUID.randomUUID().toString();
        }
        return value;
    }

    /** Appends one column name and its literal to the insert column/value lists. */
    private static void appendInsertPair(StringBuilder columns, StringBuilder values,
            Field field, Object value) {
        if (columns.length() > 0) {
            columns.append(',');
            values.append(',');
        }
        columns.append(NameConverter.conver(field.getName()));
        values.append(toSqlLiteral(value));
    }

    /** Inherited mapped fields followed by the class's own mapped fields. */
    private static java.util.List<Field> allMappedFields(Class<?> clz) {
        java.util.List<Field> fields = mappedFields(clz, true);
        fields.addAll(mappedFields(clz, false));
        return fields;
    }

    /**
     * Collects the instance fields that map to columns, already made accessible.
     *
     * @param inherited {@code true}: public fields declared in superclasses (from
     *                  {@code getFields()}) that {@code clz} does not itself declare;
     *                  {@code false}: the fields declared by {@code clz}. Static and synthetic
     *                  fields are skipped in both cases.
     */
    private static java.util.List<Field> mappedFields(Class<?> clz, boolean inherited) {
        Field[] own = clz.getDeclaredFields();
        java.util.List<Field> result = new java.util.ArrayList<Field>();
        for (Field field : inherited ? clz.getFields() : own) {
            if (java.lang.reflect.Modifier.isStatic(field.getModifiers()) || field.isSynthetic()) {
                continue;
            }
            // getFields() also returns clz's own public fields; those are handled by the own pass.
            if (inherited && (field.getDeclaringClass() == clz || declares(own, field.getName()))) {
                continue;
            }
            field.setAccessible(true);
            result.add(field);
        }
        return result;
    }

    /** Whether any of {@code fields} is named {@code name}. */
    private static boolean declares(Field[] fields, String name) {
        for (Field field : fields) {
            if (field.getName().equals(name)) {
                return true;
            }
        }
        return false;
    }

    /** Snake_case table name of the class recorded in {@code MybatisContext}. */
    private static String contextTableName() {
        String clzName = MybatisContext.getClzName();
        if (clzName == null || clzName.isEmpty()) {
            throw new SqlSessionException("MybatisContext has no class name; cannot resolve the table");
        }
        return NameConverter.conver(clzName);
    }

    /** SQL literal of a primary-key value; a null id is rejected. */
    private static String idLiteral(Object id) {
        if (id == null) {
            throw new SqlSessionException("Id 不能为 null");
        }
        return toSqlLiteral(id);
    }

    /** The current context user as a quoted literal, or {@code null} when none is set. */
    private static String currentUserLiteral() {
        String userId = MybatisContext.getUserId();
        return userId == null ? "null" : quote(userId);
    }

    /** The audit assignments added to update statements. */
    private static String updateAuditAssignments() {
        return ",update_dt=now(),update_id=" + currentUserLiteral() + " ";
    }

    /**
     * Renders a value as a SQL literal.
     *
     * <p>Dates are formatted with {@link #DATE_PATTERN} and quoted. Strings, characters and enum
     * names are quoted. BigDecimal uses its plain (non-exponent) form. Anything else uses
     * {@code toString()}, and {@code null} becomes {@code null}.
     */
    private static String toSqlLiteral(Object value) {
        if (value == null) {
            return "null";
        }
        if (value instanceof java.util.Date) {
            return quote(formatDate((java.util.Date) value));
        }
        if (value instanceof CharSequence || value instanceof Character) {
            return quote(value.toString());
        }
        if (value instanceof Enum) {
            return quote(((Enum<?>) value).name());
        }
        if (value instanceof java.math.BigDecimal) {
            return ((java.math.BigDecimal) value).toPlainString();
        }
        return String.valueOf(value);
    }

    /** Wraps text in single quotes, doubling quotes and backslashes so it cannot break out. */
    private static String quote(String text) {
        StringBuilder sb = new StringBuilder(text.length() + 2).append('\'');
        for (int i = 0; i < text.length(); i++) {
            char c = text.charAt(i);
            if (c == '\'') {
                sb.append("''");
            } else if (c == '\\') {
                sb.append("\\\\");
            } else {
                sb.append(c);
            }
        }
        return sb.append('\'').toString();
    }

    /** Formats a date with {@link #DATE_PATTERN}. */
    private static String formatDate(java.util.Date date) {
        // SimpleDateFormat is not thread-safe, so every call uses its own instance.
        return new SimpleDateFormat(DATE_PATTERN).format(date);
    }

    /**
     * Copies the quoted section starting at {@code start} (opened by ', " or `) into {@code out}.
     * Backslash escapes are honoured inside ' and " quotes.
     *
     * @return the index just past the closing quote, or the length of {@code sql} if it is unclosed
     */
    private static int copyQuoted(String sql, int start, StringBuilder out) {
        char quoteChar = sql.charAt(start);
        out.append(quoteChar);
        int i = start + 1;
        while (i < sql.length()) {
            char c = sql.charAt(i++);
            out.append(c);
            if (c == '\\' && quoteChar != '`' && i < sql.length()) {
                out.append(sql.charAt(i++)); // an escaped character cannot close the quote
            } else if (c == quoteChar) {
                return i;
            }
        }
        return i;
    }

    /** Whether {@code keyword} occurs at {@code index} as a whole word (case-insensitive). */
    private static boolean isKeywordAt(String sql, int index, String keyword) {
        int end = index + keyword.length();
        return sql.regionMatches(true, index, keyword, 0, keyword.length())
                && (index == 0 || !isIdentifierChar(sql.charAt(index - 1)))
                && (end >= sql.length() || !isIdentifierChar(sql.charAt(end)));
    }

    /** Characters that can be part of an unquoted SQL identifier. */
    private static boolean isIdentifierChar(char c) {
        return Character.isLetterOrDigit(c) || c == '_' || c == '$';
    }
}

3. MybatisContext —— 用于记录当前线程的用户基本信息

package com.jeff.mybatis.autobuild;

/**
 * Per-thread holder for request information.
 *
 * <p>It holds the current user id, an IP string, a URL string and an entity class name. Each
 * value lives in its own ThreadLocal, so it is visible only to the thread that set it. Threads
 * that are reused (e.g. pooled request threads) should call {@link #clearContext()} when their
 * work finishes, so values do not carry over into the next task.
 */
public class MybatisContext {

    // final: the holders themselves never change; only their per-thread values do.
    private static final ThreadLocal<String> userId = new ThreadLocal<String>();

    private static final ThreadLocal<String> ip = new ThreadLocal<String>();

    private static final ThreadLocal<String> url = new ThreadLocal<String>();

    private static final ThreadLocal<String> clzName = new ThreadLocal<String>();

    /** Removes all four values for the current thread. */
    public static void clearContext() {
        userId.remove();
        ip.remove();
        url.remove();
        clzName.remove();
    }

    /** @return the entity class name set for this thread, or {@code null} */
    public static String getClzName() {
        return clzName.get();
    }

    /** Sets the entity class name for this thread. */
    public static void setClzName(String _clzName) {
        clzName.set(_clzName);
    }

    /** Removes the entity class name for this thread. */
    public static void removeClzName() {
        clzName.remove();
    }

    /** @return the user id set for this thread, or {@code null} */
    public static String getUserId() {
        return userId.get();
    }

    /** Sets the user id for this thread. */
    public static void setUserId(String _userId) {
        userId.set(_userId);
    }

    /** Removes the user id for this thread. */
    public static void removeUserId() {
        userId.remove();
    }

    /** @return the IP string set for this thread, or {@code null} */
    public static String getIp() {
        return ip.get();
    }

    /** Sets the IP string for this thread. */
    public static void setIp(String _ip) {
        ip.set(_ip);
    }

    /** Removes the IP string for this thread. */
    public static void removeIp() {
        ip.remove();
    }

    /** @return the URL string set for this thread, or {@code null} */
    public static String getUrl() {
        return url.get();
    }

    /** Sets the URL string for this thread. */
    public static void setUrl(String _url) {
        url.set(_url);
    }

    /** Removes the URL string for this thread. */
    public static void removeUrl() {
        url.remove();
    }
}

4.命名转换类

package com.jeff.mybatis.autobuild;

/**
 * Converts camelCase Java names to snake_case, e.g. {@code UserInfo -> user_info}.
 *
 * @author jeff he
 *
 */
public class NameConverter {

    /**
     * Converts a camelCase name to snake_case.
     *
     * <p>The first character is lower-cased. Every later upper-case letter is replaced by
     * {@code '_'} followed by its lower-case form. All other characters (lower-case letters,
     * digits, {@code '_'}, {@code '$'}, caseless scripts) are copied unchanged. Case mapping uses
     * {@link Character}, so it does not depend on the default locale. Iteration is by code point,
     * so surrogate pairs stay intact.
     *
     * @param name the name to convert; may be {@code null}
     * @return the snake_case form, or {@code ""} for a {@code null} or empty name
     */
    public static String conver(String name) {
        if (name == null || name.isEmpty()) {
            return "";
        }
        StringBuilder result = new StringBuilder(name.length() + 8);
        int first = name.codePointAt(0);
        result.appendCodePoint(Character.toLowerCase(first));
        int i = Character.charCount(first);
        while (i < name.length()) {
            int cp = name.codePointAt(i);
            // Only genuine upper-case letters start a new word; '_' or CJK must not get a prefix.
            if (Character.isUpperCase(cp)) {
                result.append('_');
            }
            result.appendCodePoint(Character.toLowerCase(cp));
            i += Character.charCount(cp);
        }
        return result.toString();
    }

}

5. 可以编写如下的 BaseMapper，让所有 Mapper 接口继承它，这样所有 Mapper 都具备这几个通用方法，
具体的 SQL 由上面的代码自动生成。

package com.jeff.mapper;

import java.io.Serializable;
import java.util.List;

/**
 * Generic CRUD contract shared by the entity mappers.
 *
 * <p>Mapper XML files declare these statements with empty bodies; the SQL is expected to be
 * generated at run time by the auto-build interceptor.
 *
 * @param <T>  entity type
 * @param <ID> primary-key type
 */
public interface BaseMapper<T, ID extends Serializable> {

    /** Inserts {@code t}; returns the count MyBatis reports (typically affected rows). */
    int insert(T t);

    /** Updates the row identified by {@code t}'s id; returns the reported count. */
    int update(T t);

    /** Deletes the row with the given id; returns the reported count. */
    int deleteById(ID id);

    /** Loads the row with the given id. */
    T getById(ID id);

    /** Lists rows matching the non-null properties of {@code t}. */
    List<T> list(T t);

    /** Same filter as {@link #list(Object)}, intended for paged queries. */
    List<T> findByPage(T t);

}
<mapper namespace="com.jeff.mapper.UserMapper">

    <!-- Result map for User; it declares no explicit column mappings. -->
    <resultMap type="User" id="userRM">
    </resultMap>

    <!-- The section above forms one unit. -->

    <!--
        The statement bodies below are intentionally empty: AutoBuildInterceptor matches
        these statement names (insert / update / getById / deleteById / list / findByPage)
        and replaces the SQL with text generated by SqlBulider from the parameter object.
        getById / deleteById receive only the id, so the table name is taken from
        MybatisContext.getClzName(). NOTE(review): confirm the caller sets it before these
        statements run.
    -->
    <insert id="insert" parameterType="User">
    </insert>

    <delete id="deleteById" parameterType="String">
    </delete>

    <update id="update" parameterType="User">
    </update>

    <select id="getById" parameterType="String" resultType="User">
    </select>

    <select id="list" parameterType="User" resultMap="userRM">
    </select>

    <select id="findByPage" parameterType="User" resultMap="userRM">
    </select>
</mapper>

参考网址:深入浅出Mybatis-sql自动生成

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值