Mybatis interceptor 获取clickhouse最终执行的sql

MybatisClickHouseGetSqlInterceptor.java

import lombok.extern.slf4j.Slf4j;
import org.apache.ibatis.binding.MapperMethod.ParamMap;
import org.apache.ibatis.cache.CacheKey;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.ParameterMapping;
import org.apache.ibatis.plugin.*;
import org.apache.ibatis.reflection.MetaObject;
import org.apache.ibatis.session.Configuration;
import org.apache.ibatis.session.ResultHandler;
import org.apache.ibatis.session.RowBounds;
import org.apache.ibatis.type.TypeHandlerRegistry;
import ru.yandex.clickhouse.ClickHouseUtil;
import ru.yandex.clickhouse.util.ClickHouseArrayUtil;
import ru.yandex.clickhouse.util.apache.StringUtils;
import ru.yandex.clickhouse.util.guava.StreamUtils;

import java.io.InputStream;
import java.math.BigDecimal;
import java.math.BigInteger;
import java.sql.*;
import java.text.SimpleDateFormat;
import java.util.*;
import java.util.Date;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

@Intercepts({@Signature(type = Executor.class, method = "update", args = {MappedStatement.class, Object.class}), @Signature(type = Executor.class, method = "query", args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class}), @Signature(type = Executor.class, method = "query", args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class, CacheKey.class, BoundSql.class})})
@Slf4j
public class MybatisClickHouseGetSqlInterceptor implements Interceptor {



    public static void main(String[] args) {
        SimpleDateFormat simpleDateFormat = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss");
        simpleDateFormat.setTimeZone(TimeZone.getDefault());
        String format = simpleDateFormat.format(new Date());
        System.out.println(format);
    }

    @Override public Object intercept(Invocation invocation) throws Throwable {
        MappedStatement mappedStatement = (MappedStatement) invocation.getArgs()[0];
        //0.sql参数获取
        Object parameter = null;
        if (invocation.getArgs().length > 1) {
            parameter = invocation.getArgs()[1];
        }
        BoundSql boundSql = mappedStatement.getBoundSql(parameter);
        //1.方法没有参数或者第一个参数为false
        if (parameter == null || (parameter instanceof ParamMap && (((ParamMap) parameter).size() == 0 || ((ParamMap) parameter).get("param1") != Boolean.TRUE)) || (parameter instanceof Boolean && parameter != Boolean.TRUE) || (!(parameter instanceof Boolean) && !(parameter instanceof ParamMap))) {
            //集群提交任务
            return invocation.proceed();
        }
        Configuration configuration = mappedStatement.getConfiguration();
        //2. 返回sql语句
        Object[] params = getParams(configuration, boundSql);
        String rawsql = boundSql.getSql();
        String sql = PreparedStatementParser.parse(rawsql).buildSql(params);
        return Arrays.asList(sql);
    }

    @Override public Object plugin(Object o) {
        return Plugin.wrap(o, this);
    }

    @Override public void setProperties(Properties properties) {
    }
    private static Object[] getParams(Configuration configuration, BoundSql boundSql) {
        ArrayList<Object> list = new ArrayList<>();
        Object parameterObject = boundSql.getParameterObject();
        List<ParameterMapping> parameterMappings = boundSql.getParameterMappings();
        if (!parameterMappings.isEmpty() && parameterObject != null) {
            TypeHandlerRegistry typeHandlerRegistry = configuration.getTypeHandlerRegistry();
            if (typeHandlerRegistry.hasTypeHandler(parameterObject.getClass())) {
                list.add(parameterObject);
            } else {
                MetaObject metaObject = configuration.newMetaObject(parameterObject);
                for (ParameterMapping parameterMapping : parameterMappings) {
                    String propertyName = parameterMapping.getProperty();
                    if (metaObject.hasGetter(propertyName)) {
                        Object obj = metaObject.getValue(propertyName);
                        list.add(obj);
                    } else if (boundSql.hasAdditionalParameter(propertyName)) {
                        Object obj = boundSql.getAdditionalParameter(propertyName);
                        list.add(obj);
                    } else {
                        list.add("缺失");
                    }//打印出缺失,提醒该参数缺失并防止错位
                }
            }
        }
        return list.toArray();
    }
    /**
     * Parser for clickhouse JDBC SQL Strings
     * <p>
     * Tries to extract query parameters in a way that is usable for (batched)
     * prepared statements.
     */
    public final static class PreparedStatementParser {
        private static final Pattern VALUES = Pattern.compile("(?i)INSERT\\s+INTO\\s+.+VALUES\\s*\\(");
        private static final String PARAM_MARKER = "?";
        private final SimpleDateFormat dateFormat = new SimpleDateFormat("yyyy-MM-dd");
        private final SimpleDateFormat dateTimeFormat = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss");
        private final List<List<String>> parameters;
        private final List<String> parts;
        private boolean valuesMode;

        private PreparedStatementParser() {
            parameters = new ArrayList<List<String>>();
            parts = new ArrayList<String>();
            valuesMode = false;
            initTimeZone(TimeZone.getDefault());
        }

        public static PreparedStatementParser parse(String sql) {
            if (StringUtils.isBlank(sql)) {
                throw new IllegalArgumentException("SQL may not be blank");
            }
            PreparedStatementParser parser = new PreparedStatementParser();
            parser.parseSQL(sql);
            return parser;
        }

        private static String typeTransformParameterValue(String paramValue) {
            if (paramValue == null) {
                return null;
            }
            if (Boolean.TRUE.toString().equalsIgnoreCase(paramValue)) {
                return "1";
            }
            if (Boolean.FALSE.toString().equalsIgnoreCase(paramValue)) {
                return "0";
            }
            if ("NULL".equalsIgnoreCase(paramValue)) {
                return "\\N";
            }
            return paramValue;
        }

        private void initTimeZone(TimeZone timeZone) {
            dateTimeFormat.setTimeZone(timeZone);
            dateFormat.setTimeZone(timeZone);
        }


        private void reset() {
            parameters.clear();
            parts.clear();
            valuesMode = false;
        }

        private void parseSQL(String sql) {
            reset();
            List<String> currentParamList = new ArrayList<String>();
            boolean afterBackSlash = false;
            boolean inQuotes = false;
            boolean inBackQuotes = false;
            boolean inSingleLineComment = false;
            boolean inMultiLineComment = false;
            boolean whiteSpace = false;
            Matcher matcher = VALUES.matcher(sql);
            if (matcher.find()) {
                valuesMode = true;
            }
            int currentParensLevel = 0;
            int quotedStart = 0;
            int partStart = 0;
            for (int i = valuesMode ? matcher.end() - 1 : 0, idxStart = i, idxEnd = i; i < sql.length(); i++) {
                char c = sql.charAt(i);
                //注释
                if (inSingleLineComment) {
                    if (c == '\n') {
                        inSingleLineComment = false;
                    }
                } else if (inMultiLineComment) {
                    if (c == '*' && sql.length() > i + 1 && sql.charAt(i + 1) == '/') {
                        inMultiLineComment = false;
                        i++;
                    }
                    //反斜杠
                } else if (afterBackSlash) {
                    afterBackSlash = false;
                } else if (c == '\\') {
                    afterBackSlash = true;
                    //单引号
                } else if (c == '\'') {
                    inQuotes = !inQuotes;
                    if (inQuotes) {
                        quotedStart = i;
                    } else if (!afterBackSlash) {
                        idxStart = quotedStart;
                        idxEnd = i + 1;
                    }
                    //反引号
                } else if (c == '`') {
                    inBackQuotes = !inBackQuotes;
                    //不在单引号中
                } else if (!inQuotes && !inBackQuotes) {
                    if (c == '?') {
                        if (currentParensLevel > 0) {
                            idxStart = i;
                            idxEnd = i + 1;
                        }
                        if (!valuesMode) {
                            parts.add(sql.substring(partStart, i));
                            partStart = i + 1;
                            currentParamList.add(PARAM_MARKER);
                        }
                        //单行注释
                    } else if (c == '-' && sql.length() > i + 1 && sql.charAt(i + 1) == '-') {
                        inSingleLineComment = true;
                        i++;
                        //多行注释
                    } else if (c == '/' && sql.length() > i + 1 && sql.charAt(i + 1) == '*') {
                        inMultiLineComment = true;
                        i++;
                    } else if (c == ',') {
                        if (valuesMode && idxEnd > idxStart) {
                            currentParamList.add(typeTransformParameterValue(sql.substring(idxStart, idxEnd)));
                            parts.add(sql.substring(partStart, idxStart));
                            partStart = idxStart = idxEnd = i;
                        }
                        idxStart++;
                        idxEnd++;
                    } else if (c == '(') {
                        currentParensLevel++;
                        idxStart++;
                        idxEnd++;
                    } else if (c == ')') {
                        currentParensLevel--;
                        if (valuesMode && currentParensLevel == 0) {
                            if (idxEnd > idxStart) {
                                currentParamList.add(typeTransformParameterValue(sql.substring(idxStart, idxEnd)));
                                parts.add(sql.substring(partStart, idxStart));
                                partStart = idxStart = idxEnd = i;
                            }
                            if (!currentParamList.isEmpty()) {
                                parameters.add(currentParamList);
                                currentParamList = new ArrayList<String>(currentParamList.size());
                            }
                        }
                    } else if (Character.isWhitespace(c)) {
                        whiteSpace = true;
                    } else if (currentParensLevel > 0) {
                        if (whiteSpace) {
                            idxStart = i;
                            idxEnd = i + 1;
                        } else {
                            idxEnd++;
                        }
                        whiteSpace = false;
                    }
                }
            }
            if (!valuesMode && !currentParamList.isEmpty()) {
                parameters.add(currentParamList);
            }
            String lastPart = sql.substring(partStart);
            parts.add(lastPart);
        }

        private String buildSql(Object[] params) throws SQLException {
            if (parts.size() == 1) {
                return parts.get(0);
            }
            //checkBinded();检查参数是不是null
            StringBuilder sb = new StringBuilder(parts.get(0));
            for (int i = 1, p = 0; i < parts.size(); i++) {
                String pValue = getParameter(i - 1);
                if (PARAM_MARKER.equals(pValue)) {
                    appendBoundValue(sb, params[p++]);
                } else {
                    sb.append(pValue);
                }
                sb.append(parts.get(i));
            }
            String mySql = sb.toString();
            return mySql;
        }

        private String getParameter(int paramIndex) {
            for (int i = 0, count = paramIndex; i < parameters.size(); i++) {
                List<String> pList = parameters.get(i);
                count -= pList.size();
                if (count < 0) {
                    return pList.get(pList.size() + count);
                }
            }
            return null;
        }

        private void appendBoundValue(StringBuilder sb, Object params) {
            //需要根据参数类型来处理
            try {
                sb.append(getObject(params));
            } catch (SQLException e) {
                e.printStackTrace();
            }
        }

        private StringBuilder getBind(String bind, boolean quote) {
            StringBuilder sb = new StringBuilder();
            if (quote) {
                sb.append("'").append(bind).append("'");
            } else if (bind.equals("\\N")) {
                sb.append("null");
            } else {
                sb.append(bind);
            }
            return sb;
        }

        private StringBuilder getBind(String bind) {
            return getBind(bind, false);
        }

        public StringBuilder getNull(int sqlType) throws SQLException {
            return getBind("\\N");
        }

        public StringBuilder getBoolean(boolean x) throws SQLException {
            return getBind(x ? "1" : "0");
        }

        public StringBuilder getByte(byte x) throws SQLException {
            return getBind(Byte.toString(x));
        }

        public StringBuilder getShort(short x) throws SQLException {
            return getBind(Short.toString(x));
        }

        public StringBuilder getInt(int x) throws SQLException {
            return getBind(Integer.toString(x));
        }

        public StringBuilder getLong(long x) throws SQLException {
            return getBind(Long.toString(x));
        }

        public StringBuilder getFloat(float x) throws SQLException {
            return getBind(Float.toString(x));
        }

        public StringBuilder getDouble(double x) throws SQLException {
            return getBind(Double.toString(x));
        }

        public StringBuilder getBigDecimal(BigDecimal x) throws SQLException {
            return getBind(x.toPlainString());
        }

        public StringBuilder getString(String x) throws SQLException {
            return getBind(ClickHouseUtil.escape(x), x != null);
        }

        public StringBuilder getBytes(byte[] x) throws SQLException {
            return getBind(new String(x, StreamUtils.UTF_8));
        }

        public StringBuilder getDate(java.sql.Date x) throws SQLException {
            return getBind(dateFormat.format(x), true);
        }

        public StringBuilder getTime(Time x) throws SQLException {
            return getBind(dateTimeFormat.format(x), true);
        }

        public StringBuilder getTimestamp(Timestamp x) throws SQLException {
            return getBind(dateTimeFormat.format(x), true);
        }

        public StringBuilder getAsciiStream(InputStream x, int length) throws SQLException {
            throw new SQLFeatureNotSupportedException();
        }

        @Deprecated public StringBuilder getUnicodeStream(InputStream x, int length) throws SQLException {
            throw new SQLFeatureNotSupportedException();
        }

        public StringBuilder getBinaryStream(InputStream x, int length) throws SQLException {
            throw new SQLFeatureNotSupportedException();
        }

        public StringBuilder getObject(Object x, int targetSqlType) throws SQLException {
            return getObject(x);
        }

        public StringBuilder getArray(Collection collection) throws SQLException {
            return getBind(ClickHouseArrayUtil.toString(collection));
        }

        public StringBuilder getArray(Object[] array) throws SQLException {
            return getBind(ClickHouseArrayUtil.toString(array));
        }

        public StringBuilder getClob(Clob x) throws SQLException {
            throw new SQLFeatureNotSupportedException();
        }

        public StringBuilder getBlob(Blob x) throws SQLException {
            throw new SQLFeatureNotSupportedException();
        }

        public StringBuilder getObject(Object x) throws SQLException {
            if (x == null) {
                return getNull(Types.OTHER);
            } else {
                if (x instanceof Byte) {
                    return getInt(((Byte) x).intValue());
                } else if (x instanceof String) {
                    return getString((String) x);
                } else if (x instanceof BigDecimal) {
                    return getBigDecimal((BigDecimal) x);
                } else if (x instanceof Short) {
                    return getShort(((Short) x).shortValue());
                } else if (x instanceof Integer) {
                    return getInt(((Integer) x).intValue());
                } else if (x instanceof Long) {
                    return getLong(((Long) x).longValue());
                } else if (x instanceof Float) {
                    return getFloat(((Float) x).floatValue());
                } else if (x instanceof Double) {
                    return getDouble(((Double) x).doubleValue());
                } else if (x instanceof byte[]) {
                    return getBytes((byte[]) x);
                } else if (x instanceof java.sql.Date) {
                    return getDate((java.sql.Date) x);
                } else if (x instanceof Time) {
                    return getTime((Time) x);
                } else if (x instanceof Timestamp) {
                    return getTimestamp((Timestamp) x);
                } else if (x instanceof Boolean) {
                    return getBoolean(((Boolean) x).booleanValue());
                } else if (x instanceof InputStream) {
                    return getBinaryStream((InputStream) x, -1);
                } else if (x instanceof Blob) {
                    return getBlob((Blob) x);
                } else if (x instanceof Clob) {
                    return getClob((Clob) x);
                } else if (x instanceof BigInteger) {
                    return getBind(x.toString());
                } else if (x instanceof UUID) {
                    return getString(x.toString());
                } else if (x instanceof Collection) {
                    return getArray((Collection) x);
                } else if (x.getClass().isArray()) {
                    return getArray((Object[]) x);
                } else {
                    throw new SQLDataException("Can't bind object of class " + x.getClass().getCanonicalName());
                }
            }
        }
    }
}




使用:

1 添加mybatis拦截器:(注意:要把它作为第一个拦截器添加。mybatis拦截器按添加顺序的倒序执行,第一个添加的最后一个执行)

// Build the plugin list first; the SQL-returning interceptor is added before any other plugin.
List<Interceptor> interceptors = new ArrayList<>();
interceptors.add(new MybatisClickHouseGetSqlInterceptor());

MybatisPrepareInterceptor mybatisPrepareInterceptor = new MybatisPrepareInterceptor();
mybatisPrepareInterceptor.setMonitor(this.defaultMonitor);
interceptors.add(mybatisPrepareInterceptor);

// Configure the session factory and register the plugins in the order collected above.
SqlSessionFactoryBean sessionFactoryBean = new SqlSessionFactoryBean();
sessionFactoryBean.setDataSource(dataSource);
sessionFactoryBean.setVfs(SpringBootVFS.class);
sessionFactoryBean.setPlugins(interceptors.toArray(new Interceptor[0]));

2 写mapper:

注意:mapper方法需要满足两个条件:

1) 第一个参数为boolean,且传入true,

2) 返回值为string

其他跟mybatis的正常使用一样,调用完成后就会返回替换好参数的sql语句,也不会有sql注入问题。

("<script>select xxxx<if></script>")

String getsql(boolean isSql,@Param("query") xxxxx,xxx);

注意:因为不同引擎(mysql、hive、clickhouse等)对于特殊字符的处理不同,所以将此段代码用于clickhouse之外的引擎时请谨慎使用。如果要替换引擎,需要参考对应引擎jdbc驱动中PreparedStatement的实现,去替换本代码中的 

String sql = PreparedStatementParser.parse(rawsql).buildSql(params); 这一行。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

i am cscs

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值