package com.test.examination.manage.util;
import com.alibaba.druid.sql.ast.SQLStatement;
import com.alibaba.druid.sql.dialect.mysql.parser.MySqlStatementParser;
import com.alibaba.druid.sql.dialect.mysql.visitor.MySqlSchemaStatVisitor;
import com.alibaba.druid.stat.TableStat;
import com.bjupi.exception.MessageException;
import lombok.Data;
import java.text.NumberFormat;
import java.util.*;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.regex.PatternSyntaxException;
/**
 * Rewrites the column names in a MySQL statement into upper-case, underscore-separated form,
 * e.g. {@code a.userName} becomes {@code a.USER_NAME}.
 *
 * <p>Columns and table aliases are discovered with Druid's {@link MySqlSchemaStatVisitor}; the
 * rewrite itself is a set of plain-text replacements around a fixed list of delimiters (space,
 * comma, {@code =}, parentheses and backticks), so it is a best-effort transformation rather than
 * a re-serialisation of the parsed AST.
 *
 * @author : maxiaojie
 * @date : 2023/1/30 9:24
 */
public class SqlParseHandleUtil {

    /** Captures the text between a '(' and the next ')' (the captured text never contains ')'). */
    private static final Pattern PARENTHESISED = Pattern.compile("\\(([^\\)]+)\\)");

    /**
     * Marker compared against Druid column table names, and appended to alias lists to request
     * that a column additionally be rewritten without any alias prefix.
     */
    private static final String UNKNOWN = "UNKNOWN";

    /** Width of the zero-padded numeric placeholders that temporarily replace sub-queries. */
    private static final int PLACEHOLDER_DIGITS = 10;

    /**
     * {prefix, suffix} contexts in which a column name is rewritten, applied in this order.
     * A {@code null} suffix means "followed by anything except an identifier character", so a
     * short name is never rewritten inside a longer identifier that merely starts with it.
     */
    private static final String[][] REPLACE_CONTEXTS = {
            {" ", ","}, {",", ","}, {" ", " "}, {",", " "}, {",", "="}, {" ", "="},
            {"`", "`"}, {"=", " "}, {" ", null},
            {"(", ","}, {"(", " "}, {"(", "="}, {"(", ")"}, {",", ")"},
    };

    /**
     * Converts every column referenced by {@code sql} to UPPER_SNAKE_CASE.
     *
     * @param sql a single MySQL statement
     * @return the statement text with its column names rewritten
     * @throws MessageException if an extracted lower-case {@code select} sub-query itself contains
     *         parentheses (nested sub-queries / function calls inside it are not supported);
     *         exceptions raised by the Druid parser propagate unchanged
     */
    public static String handle(String sql) {
        return handleSelect(sql, true);
    }

    /**
     * Converts a camelCase identifier to UPPER_SNAKE_CASE; dots become underscores.
     * An underscore is inserted only before an upper-case letter that sits between two
     * lower-case letters, so acronyms such as {@code userID} stay joined.
     *
     * @param text identifier to convert
     * @return the converted identifier
     */
    private static String getUper(String text) {
        StringBuilder builder = new StringBuilder(text.replace('.', '_'));
        // The bound must be re-evaluated on every iteration: each insertion grows the builder,
        // and a cached bound would leave the tail unchecked ("userNameId" -> "USER_NAMEID").
        for (int i = 1; i < builder.length() - 1; i++) {
            if (isUnderscoreRequired(builder.charAt(i - 1), builder.charAt(i), builder.charAt(i + 1))) {
                // Insert before the upper-case letter, then step past it.
                builder.insert(i++, '_');
            }
        }
        // Locale.ROOT keeps the result independent of the default locale (e.g. Turkish 'i').
        return builder.toString().toUpperCase(Locale.ROOT);
    }

    /** Returns true for a lower/upper/lower character triple, i.e. a camelCase word boundary. */
    private static boolean isUnderscoreRequired(char before, char current, char after) {
        return Character.isLowerCase(before) && Character.isUpperCase(current) && Character.isLowerCase(after);
    }

    /**
     * Rewrites the columns of one statement.
     *
     * @param sql   statement (or extracted sub-query) text
     * @param first {@code true} for the top-level statement, whose lower-case {@code select}
     *              sub-queries are first extracted and rewritten separately
     * @return the rewritten statement, with any extracted sub-queries put back
     */
    private static String handleSelect(String sql, boolean first) {
        SQLInner sqlInner;
        if (first) {
            sqlInner = extractHandle(sql);
        } else {
            sqlInner = new SQLInner();
            sqlInner.setSql(sql);
        }
        String sqlReturn = sqlInner.getSql();

        // Parse the statement into an AST and let the visitor collect tables, aliases and columns.
        MySqlStatementParser parser = new MySqlStatementParser(sqlReturn);
        SQLStatement statement = parser.parseStatement();
        MySqlSchemaStatVisitor visitor = new MySqlSchemaStatVisitor();
        statement.accept(visitor);

        Map<String, List<String>> tableAlias = groupAliasesByTable(visitor.getAliasMap(), sql);
        boolean isDelete = startsWithKeyword(sqlReturn, "delete");

        for (TableStat.Column column : visitor.getColumns()) {
            String name = column.getName();
            if ("*".equals(name)) {
                continue;
            }
            String table = column.getTable();
            List<String> aliases = tableAlias.get(table);
            if (aliases == null) {
                // No alias is known for this column's table: rewrite the bare name only when the
                // table is the UNKNOWN marker or the statement is a DELETE.
                if (UNKNOWN.equals(table) || isDelete) {
                    sqlReturn = replaceRule(sqlReturn, name, getUper(name));
                }
            } else {
                String upper = getUper(name);
                for (String alias : aliases) {
                    if (UNKNOWN.equals(table) || UNKNOWN.equals(alias)) {
                        sqlReturn = replaceRule(sqlReturn, name, upper);
                    } else {
                        sqlReturn = replaceRule(sqlReturn, alias + "." + name, alias + "." + upper);
                    }
                }
            }
        }
        return restoreSubQueries(sqlReturn, sqlInner.getKSql());
    }

    /**
     * Inverts Druid's alias map (alias -> table) into table -> aliases, adding the UNKNOWN
     * marker where columns of that table may also appear without an alias prefix.
     *
     * @param aliasMap alias-to-table map from the visitor; may be {@code null}
     * @param sql      statement text, used to detect INSERT/UPDATE statements
     * @return mutable map of table name to its aliases (never {@code null})
     */
    private static Map<String, List<String>> groupAliasesByTable(Map<String, String> aliasMap, String sql) {
        Map<String, List<String>> tableAlias = new HashMap<>();
        if (aliasMap == null) {
            return tableAlias;
        }
        // INSERT/UPDATE statements get UNKNOWN next to every alias so bare column names are
        // rewritten as well as prefixed ones.
        boolean bareColumns = startsWithKeyword(sql, "update") || startsWithKeyword(sql, "insert");
        aliasMap.forEach((alias, table) -> {
            List<String> aliases = tableAlias.computeIfAbsent(table, key -> new ArrayList<>());
            aliases.add(alias);
            if (bareColumns) {
                aliases.add(UNKNOWN);
            }
        });
        // A lone alias with the same length as the table name is treated as the table itself
        // (no real alias), so its columns may also appear unprefixed.
        tableAlias.forEach((table, aliases) -> {
            if (table != null && aliases.size() == 1 && aliases.get(0).length() == table.length()) {
                aliases.add(UNKNOWN);
            }
        });
        return tableAlias;
    }

    /**
     * Returns true when {@code text}, after leading whitespace, starts with {@code keyword}
     * (case-insensitively) followed by a whitespace character.
     */
    private static boolean startsWithKeyword(String text, String keyword) {
        int start = 0;
        while (start < text.length() && Character.isWhitespace(text.charAt(start))) {
            start++;
        }
        int end = start + keyword.length();
        return end < text.length()
                && text.regionMatches(true, start, keyword, 0, keyword.length())
                && Character.isWhitespace(text.charAt(end));
    }

    /**
     * Replaces {@code tempStr} by {@code temptraget} wherever it appears in one of
     * {@link #REPLACE_CONTEXTS}. Matching is literal (no regex interpretation of either string).
     *
     * @param sqlReturn  text to rewrite
     * @param tempStr    original column reference, e.g. {@code a.userName}
     * @param temptraget replacement, e.g. {@code a.USER_NAME}
     * @return the rewritten text
     */
    private static String replaceRule(String sqlReturn, String tempStr, String temptraget) {
        if (tempStr.isEmpty()) {
            return sqlReturn;
        }
        String result = sqlReturn;
        for (String[] context : REPLACE_CONTEXTS) {
            String prefix = context[0];
            String suffix = context[1];
            if (suffix == null) {
                // Only when not immediately followed by an identifier character, so that
                // " user" never rewrites the start of " userName".
                result = Pattern.compile(Pattern.quote(prefix + tempStr) + "(?![\\w$])")
                        .matcher(result)
                        .replaceAll(Matcher.quoteReplacement(prefix + temptraget));
            } else {
                result = result.replace(prefix + tempStr + suffix, prefix + temptraget + suffix);
            }
        }
        return result;
    }

    /**
     * Extracts parenthesised lower-case {@code select} sub-queries, rewrites each one on its own
     * and substitutes a unique numeric placeholder for it in the outer statement.
     *
     * @param sqlOrigin top-level statement text
     * @return the outer SQL with placeholders plus the placeholder -> rewritten sub-query map
     * @throws MessageException if a sub-query contains parentheses of its own
     */
    private static SQLInner extractHandle(String sqlOrigin) {
        List<String> list = extractMessageByRegular(sqlOrigin);
        SQLInner sqlInner = new SQLInner();
        if (!list.isEmpty()) {
            List<String> inners = new ArrayList<>();
            LinkedHashMap<String, String> kSql = new LinkedHashMap<>();
            int i = 1;
            for (String s : list) {
                // Only "select ..." / " select ..." groups are extracted; anything else stays
                // inline and is rewritten by the enclosing statement's own parse.
                if (!(s.startsWith("select ") || s.startsWith(" select "))) {
                    continue;
                }
                if (s.indexOf('(') >= 0) {
                    // The pattern stops at the first ')', so a '(' here means the sub-query was
                    // cut short inside a nested group; such SQL is not supported.
                    throw new MessageException("SQL NESTED EXCEPTION");
                }
                String placeholder = genInitNumber(PLACEHOLDER_DIGITS, i);
                // Literal, paren-anchored replacement: the sub-query text is not a regex.
                sqlOrigin = sqlOrigin.replace("(" + s + ")", "(" + placeholder + ")");
                inners.add(s);
                kSql.put(placeholder, handleSelect(s, false));
                i++;
            }
            sqlInner.setInners(inners);
            sqlInner.setKSql(kSql);
        }
        sqlInner.setSql(sqlOrigin);
        return sqlInner;
    }

    /**
     * Puts rewritten sub-queries back in place of their "(placeholder)" markers.
     *
     * @param sql  outer statement containing placeholders
     * @param kSql placeholder -> rewritten sub-query; may be {@code null}
     * @return the reassembled statement
     */
    private static String restoreSubQueries(String sql, Map<String, String> kSql) {
        if (kSql == null) {
            return sql;
        }
        String result = sql;
        for (Map.Entry<String, String> entry : kSql.entrySet()) {
            // Literal replace: the sub-query may contain '$' or '\', which replaceAll would
            // treat as group references / escapes.
            result = result.replace("(" + entry.getKey() + ")", "(" + entry.getValue() + ")");
        }
        return result;
    }

    /**
     * Returns the contents of every "(...)" group whose contents contain no ')'.
     *
     * @param msg text to scan
     * @return captured group contents in order of appearance (possibly empty, never null)
     */
    private static List<String> extractMessageByRegular(String msg) {
        List<String> groups = new ArrayList<>();
        Matcher matcher = PARENTHESISED.matcher(msg);
        while (matcher.find()) {
            groups.add(matcher.group(1));
        }
        return groups;
    }

    /**
     * Intermediate state of a statement being rewritten.
     */
    @Data
    static class SQLInner {
        /** Statement text, with extracted sub-queries replaced by placeholders. */
        String sql;
        /** Original text of each extracted sub-query. */
        List<String> inners;
        /** Placeholder -> rewritten sub-query, in extraction order. */
        LinkedHashMap<String, String> kSql;
    }

    /**
     * Formats {@code value} zero-padded to at least {@code digits} digits, e.g. (10, 1) gives
     * "0000000001", so every sub-query index yields a distinct placeholder.
     *
     * @param digits minimum number of digits
     * @param value  number to format
     * @return the zero-padded ASCII digits
     */
    private static String genInitNumber(int digits, int value) {
        // Locale.ROOT guarantees ASCII digits regardless of the default locale.
        NumberFormat formatter = NumberFormat.getIntegerInstance(Locale.ROOT);
        formatter.setMinimumIntegerDigits(digits);
        formatter.setGroupingUsed(false);
        return formatter.format(value);
    }
}