MybatisClickHouseGetSqlInterceptor.java
import lombok.extern.slf4j.Slf4j;
import org.apache.ibatis.binding.MapperMethod.ParamMap;
import org.apache.ibatis.cache.CacheKey;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.ParameterMapping;
import org.apache.ibatis.plugin.*;
import org.apache.ibatis.reflection.MetaObject;
import org.apache.ibatis.session.Configuration;
import org.apache.ibatis.session.ResultHandler;
import org.apache.ibatis.session.RowBounds;
import org.apache.ibatis.type.TypeHandlerRegistry;
import ru.yandex.clickhouse.ClickHouseUtil;
import ru.yandex.clickhouse.util.ClickHouseArrayUtil;
import ru.yandex.clickhouse.util.apache.StringUtils;
import ru.yandex.clickhouse.util.guava.StreamUtils;

import java.io.InputStream;
import java.math.BigDecimal;
import java.math.BigInteger;
import java.sql.*;
import java.text.SimpleDateFormat;
import java.time.LocalDate;
import java.time.LocalDateTime;
import java.time.format.DateTimeFormatter;
import java.util.*;
import java.util.Date;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
@Intercepts({@Signature(type = Executor.class, method = "update", args = {MappedStatement.class, Object.class}), @Signature(type = Executor.class, method = "query", args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class}), @Signature(type = Executor.class, method = "query", args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class, CacheKey.class, BoundSql.class})})
@Slf4j
public class MybatisClickHouseGetSqlInterceptor implements Interceptor {
public static void main(String[] args) {
SimpleDateFormat simpleDateFormat = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss");
simpleDateFormat.setTimeZone(TimeZone.getDefault());
String format = simpleDateFormat.format(new Date());
System.out.println(format);
}
@Override public Object intercept(Invocation invocation) throws Throwable {
MappedStatement mappedStatement = (MappedStatement) invocation.getArgs()[0];
//0.sql参数获取
Object parameter = null;
if (invocation.getArgs().length > 1) {
parameter = invocation.getArgs()[1];
}
BoundSql boundSql = mappedStatement.getBoundSql(parameter);
//1.方法没有参数或者第一个参数为false
if (parameter == null || (parameter instanceof ParamMap && (((ParamMap) parameter).size() == 0 || ((ParamMap) parameter).get("param1") != Boolean.TRUE)) || (parameter instanceof Boolean && parameter != Boolean.TRUE) || (!(parameter instanceof Boolean) && !(parameter instanceof ParamMap))) {
//集群提交任务
return invocation.proceed();
}
Configuration configuration = mappedStatement.getConfiguration();
//2. 返回sql语句
Object[] params = getParams(configuration, boundSql);
String rawsql = boundSql.getSql();
String sql = PreparedStatementParser.parse(rawsql).buildSql(params);
return Arrays.asList(sql);
}
@Override public Object plugin(Object o) {
return Plugin.wrap(o, this);
}
@Override public void setProperties(Properties properties) {
}
private static Object[] getParams(Configuration configuration, BoundSql boundSql) {
ArrayList<Object> list = new ArrayList<>();
Object parameterObject = boundSql.getParameterObject();
List<ParameterMapping> parameterMappings = boundSql.getParameterMappings();
if (!parameterMappings.isEmpty() && parameterObject != null) {
TypeHandlerRegistry typeHandlerRegistry = configuration.getTypeHandlerRegistry();
if (typeHandlerRegistry.hasTypeHandler(parameterObject.getClass())) {
list.add(parameterObject);
} else {
MetaObject metaObject = configuration.newMetaObject(parameterObject);
for (ParameterMapping parameterMapping : parameterMappings) {
String propertyName = parameterMapping.getProperty();
if (metaObject.hasGetter(propertyName)) {
Object obj = metaObject.getValue(propertyName);
list.add(obj);
} else if (boundSql.hasAdditionalParameter(propertyName)) {
Object obj = boundSql.getAdditionalParameter(propertyName);
list.add(obj);
} else {
list.add("缺失");
}//打印出缺失,提醒该参数缺失并防止错位
}
}
}
return list.toArray();
}
/**
* Parser for clickhouse JDBC SQL Strings
* <p>
* Tries to extract query parameters in a way that is usable for (batched)
* prepared statements.
*/
public final static class PreparedStatementParser {
private static final Pattern VALUES = Pattern.compile("(?i)INSERT\\s+INTO\\s+.+VALUES\\s*\\(");
private static final String PARAM_MARKER = "?";
private final SimpleDateFormat dateFormat = new SimpleDateFormat("yyyy-MM-dd");
private final SimpleDateFormat dateTimeFormat = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss");
private final List<List<String>> parameters;
private final List<String> parts;
private boolean valuesMode;
private PreparedStatementParser() {
parameters = new ArrayList<List<String>>();
parts = new ArrayList<String>();
valuesMode = false;
initTimeZone(TimeZone.getDefault());
}
public static PreparedStatementParser parse(String sql) {
if (StringUtils.isBlank(sql)) {
throw new IllegalArgumentException("SQL may not be blank");
}
PreparedStatementParser parser = new PreparedStatementParser();
parser.parseSQL(sql);
return parser;
}
private static String typeTransformParameterValue(String paramValue) {
if (paramValue == null) {
return null;
}
if (Boolean.TRUE.toString().equalsIgnoreCase(paramValue)) {
return "1";
}
if (Boolean.FALSE.toString().equalsIgnoreCase(paramValue)) {
return "0";
}
if ("NULL".equalsIgnoreCase(paramValue)) {
return "\\N";
}
return paramValue;
}
private void initTimeZone(TimeZone timeZone) {
dateTimeFormat.setTimeZone(timeZone);
dateFormat.setTimeZone(timeZone);
}
private void reset() {
parameters.clear();
parts.clear();
valuesMode = false;
}
private void parseSQL(String sql) {
reset();
List<String> currentParamList = new ArrayList<String>();
boolean afterBackSlash = false;
boolean inQuotes = false;
boolean inBackQuotes = false;
boolean inSingleLineComment = false;
boolean inMultiLineComment = false;
boolean whiteSpace = false;
Matcher matcher = VALUES.matcher(sql);
if (matcher.find()) {
valuesMode = true;
}
int currentParensLevel = 0;
int quotedStart = 0;
int partStart = 0;
for (int i = valuesMode ? matcher.end() - 1 : 0, idxStart = i, idxEnd = i; i < sql.length(); i++) {
char c = sql.charAt(i);
//注释
if (inSingleLineComment) {
if (c == '\n') {
inSingleLineComment = false;
}
} else if (inMultiLineComment) {
if (c == '*' && sql.length() > i + 1 && sql.charAt(i + 1) == '/') {
inMultiLineComment = false;
i++;
}
//反斜杠
} else if (afterBackSlash) {
afterBackSlash = false;
} else if (c == '\\') {
afterBackSlash = true;
//单引号
} else if (c == '\'') {
inQuotes = !inQuotes;
if (inQuotes) {
quotedStart = i;
} else if (!afterBackSlash) {
idxStart = quotedStart;
idxEnd = i + 1;
}
//反引号
} else if (c == '`') {
inBackQuotes = !inBackQuotes;
//不在单引号中
} else if (!inQuotes && !inBackQuotes) {
if (c == '?') {
if (currentParensLevel > 0) {
idxStart = i;
idxEnd = i + 1;
}
if (!valuesMode) {
parts.add(sql.substring(partStart, i));
partStart = i + 1;
currentParamList.add(PARAM_MARKER);
}
//单行注释
} else if (c == '-' && sql.length() > i + 1 && sql.charAt(i + 1) == '-') {
inSingleLineComment = true;
i++;
//多行注释
} else if (c == '/' && sql.length() > i + 1 && sql.charAt(i + 1) == '*') {
inMultiLineComment = true;
i++;
} else if (c == ',') {
if (valuesMode && idxEnd > idxStart) {
currentParamList.add(typeTransformParameterValue(sql.substring(idxStart, idxEnd)));
parts.add(sql.substring(partStart, idxStart));
partStart = idxStart = idxEnd = i;
}
idxStart++;
idxEnd++;
} else if (c == '(') {
currentParensLevel++;
idxStart++;
idxEnd++;
} else if (c == ')') {
currentParensLevel--;
if (valuesMode && currentParensLevel == 0) {
if (idxEnd > idxStart) {
currentParamList.add(typeTransformParameterValue(sql.substring(idxStart, idxEnd)));
parts.add(sql.substring(partStart, idxStart));
partStart = idxStart = idxEnd = i;
}
if (!currentParamList.isEmpty()) {
parameters.add(currentParamList);
currentParamList = new ArrayList<String>(currentParamList.size());
}
}
} else if (Character.isWhitespace(c)) {
whiteSpace = true;
} else if (currentParensLevel > 0) {
if (whiteSpace) {
idxStart = i;
idxEnd = i + 1;
} else {
idxEnd++;
}
whiteSpace = false;
}
}
}
if (!valuesMode && !currentParamList.isEmpty()) {
parameters.add(currentParamList);
}
String lastPart = sql.substring(partStart);
parts.add(lastPart);
}
private String buildSql(Object[] params) throws SQLException {
if (parts.size() == 1) {
return parts.get(0);
}
//checkBinded();检查参数是不是null
StringBuilder sb = new StringBuilder(parts.get(0));
for (int i = 1, p = 0; i < parts.size(); i++) {
String pValue = getParameter(i - 1);
if (PARAM_MARKER.equals(pValue)) {
appendBoundValue(sb, params[p++]);
} else {
sb.append(pValue);
}
sb.append(parts.get(i));
}
String mySql = sb.toString();
return mySql;
}
private String getParameter(int paramIndex) {
for (int i = 0, count = paramIndex; i < parameters.size(); i++) {
List<String> pList = parameters.get(i);
count -= pList.size();
if (count < 0) {
return pList.get(pList.size() + count);
}
}
return null;
}
private void appendBoundValue(StringBuilder sb, Object params) {
//需要根据参数类型来处理
try {
sb.append(getObject(params));
} catch (SQLException e) {
e.printStackTrace();
}
}
private StringBuilder getBind(String bind, boolean quote) {
StringBuilder sb = new StringBuilder();
if (quote) {
sb.append("'").append(bind).append("'");
} else if (bind.equals("\\N")) {
sb.append("null");
} else {
sb.append(bind);
}
return sb;
}
private StringBuilder getBind(String bind) {
return getBind(bind, false);
}
public StringBuilder getNull(int sqlType) throws SQLException {
return getBind("\\N");
}
public StringBuilder getBoolean(boolean x) throws SQLException {
return getBind(x ? "1" : "0");
}
public StringBuilder getByte(byte x) throws SQLException {
return getBind(Byte.toString(x));
}
public StringBuilder getShort(short x) throws SQLException {
return getBind(Short.toString(x));
}
public StringBuilder getInt(int x) throws SQLException {
return getBind(Integer.toString(x));
}
public StringBuilder getLong(long x) throws SQLException {
return getBind(Long.toString(x));
}
public StringBuilder getFloat(float x) throws SQLException {
return getBind(Float.toString(x));
}
public StringBuilder getDouble(double x) throws SQLException {
return getBind(Double.toString(x));
}
public StringBuilder getBigDecimal(BigDecimal x) throws SQLException {
return getBind(x.toPlainString());
}
public StringBuilder getString(String x) throws SQLException {
return getBind(ClickHouseUtil.escape(x), x != null);
}
public StringBuilder getBytes(byte[] x) throws SQLException {
return getBind(new String(x, StreamUtils.UTF_8));
}
public StringBuilder getDate(java.sql.Date x) throws SQLException {
return getBind(dateFormat.format(x), true);
}
public StringBuilder getTime(Time x) throws SQLException {
return getBind(dateTimeFormat.format(x), true);
}
public StringBuilder getTimestamp(Timestamp x) throws SQLException {
return getBind(dateTimeFormat.format(x), true);
}
public StringBuilder getAsciiStream(InputStream x, int length) throws SQLException {
throw new SQLFeatureNotSupportedException();
}
@Deprecated public StringBuilder getUnicodeStream(InputStream x, int length) throws SQLException {
throw new SQLFeatureNotSupportedException();
}
public StringBuilder getBinaryStream(InputStream x, int length) throws SQLException {
throw new SQLFeatureNotSupportedException();
}
public StringBuilder getObject(Object x, int targetSqlType) throws SQLException {
return getObject(x);
}
public StringBuilder getArray(Collection collection) throws SQLException {
return getBind(ClickHouseArrayUtil.toString(collection));
}
public StringBuilder getArray(Object[] array) throws SQLException {
return getBind(ClickHouseArrayUtil.toString(array));
}
public StringBuilder getClob(Clob x) throws SQLException {
throw new SQLFeatureNotSupportedException();
}
public StringBuilder getBlob(Blob x) throws SQLException {
throw new SQLFeatureNotSupportedException();
}
public StringBuilder getObject(Object x) throws SQLException {
if (x == null) {
return getNull(Types.OTHER);
} else {
if (x instanceof Byte) {
return getInt(((Byte) x).intValue());
} else if (x instanceof String) {
return getString((String) x);
} else if (x instanceof BigDecimal) {
return getBigDecimal((BigDecimal) x);
} else if (x instanceof Short) {
return getShort(((Short) x).shortValue());
} else if (x instanceof Integer) {
return getInt(((Integer) x).intValue());
} else if (x instanceof Long) {
return getLong(((Long) x).longValue());
} else if (x instanceof Float) {
return getFloat(((Float) x).floatValue());
} else if (x instanceof Double) {
return getDouble(((Double) x).doubleValue());
} else if (x instanceof byte[]) {
return getBytes((byte[]) x);
} else if (x instanceof java.sql.Date) {
return getDate((java.sql.Date) x);
} else if (x instanceof Time) {
return getTime((Time) x);
} else if (x instanceof Timestamp) {
return getTimestamp((Timestamp) x);
} else if (x instanceof Boolean) {
return getBoolean(((Boolean) x).booleanValue());
} else if (x instanceof InputStream) {
return getBinaryStream((InputStream) x, -1);
} else if (x instanceof Blob) {
return getBlob((Blob) x);
} else if (x instanceof Clob) {
return getClob((Clob) x);
} else if (x instanceof BigInteger) {
return getBind(x.toString());
} else if (x instanceof UUID) {
return getString(x.toString());
} else if (x instanceof Collection) {
return getArray((Collection) x);
} else if (x.getClass().isArray()) {
return getArray((Object[]) x);
} else {
throw new SQLDataException("Can't bind object of class " + x.getClass().getCanonicalName());
}
}
}
}
}
使用:
1 添加mybatis拦截器:(注意要添加到第一个,mybatis拦截器是倒序执行。第一个添加的最后一个执行)
SqlSessionFactoryBean sessionFactoryBean = new SqlSessionFactoryBean();
sessionFactoryBean.setDataSource(dataSource);
sessionFactoryBean.setVfs(SpringBootVFS.class);
List<Interceptor> interceptors = new ArrayList<>();
interceptors.add(new MybatisClickHouseGetSqlInterceptor());
MybatisPrepareInterceptor mybatisPrepareInterceptor = new MybatisPrepareInterceptor();
mybatisPrepareInterceptor.setMonitor(this.defaultMonitor);
interceptors.add(mybatisPrepareInterceptor);
sessionFactoryBean.setPlugins((Interceptor[])interceptors.toArray(new Interceptor[interceptors.size()]));
2 写mapper:
注意:mapper方法需要满足两个条件:
1) 第一个参数为boolean,且传入true;
2) 返回值为String(只适用于select查询语句;insert/update/delete方法的返回值是影响行数,无法返回SQL)。
其他跟mybatis使用一样,调用完成就会返回替换好参数的sql语句,也不会有sql注入问题。
@Select("<script>select xxxx <if test='...'>...</if></script>")
String getsql(boolean isSql,@Param("query") xxxxx,xxx);
注意:因为不同引擎(mysql、hive、clickhouse等)对于特殊字符的处理不同,所以将此段代码用于clickhouse之外的引擎时请谨慎。如果要替换引擎,需要参考对应引擎JDBC驱动中PreparedStatement的实现,替换本代码中的
String sql = PreparedStatementParser.parse(rawsql).buildSql(params); 这一行。