话不多说,直接上验证通过的代码。
第一个例子(用 JSqlParser 根据库名、表名、字段、字段别名、where 条件和 group by 字段生成单表查询 SQL):
package jdbc;
import net.sf.jsqlparser.JSQLParserException;
import net.sf.jsqlparser.expression.Alias;
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
import net.sf.jsqlparser.schema.Column;
import net.sf.jsqlparser.schema.Database;
import net.sf.jsqlparser.schema.Table;
import net.sf.jsqlparser.statement.select.PlainSelect;
import net.sf.jsqlparser.statement.select.Select;
import net.sf.jsqlparser.statement.select.SelectExpressionItem;
import org.apache.commons.lang3.StringUtils;
/**
 * Demonstrates building a single-table SELECT statement programmatically with JSqlParser.
 */
public class GenSelectSqlBySqlParser {

    public static void main(String[] args) throws JSQLParserException {
        // Inputs
        String database = "test"; // database (MySQL schema) that qualifies the table
        String table = "user"; // table name
        // Select-list expressions: plain columns or (nested) function calls, parsed by JSqlParser.
        String[] fields = {"id", "name", "age", "count(id)", "avg(age)", "sum(a)", "calculate_sum(concat_ws(',', collect_list(id)))"};
        // Alias for the entry of `fields` at the same index; an empty string means "no alias".
        String[] fieldsAliases = {"", "", "", "c", "a", "s", "t"};
        String where = "age > 20"; // WHERE condition
        String groupBy = "name"; // GROUP BY expression
        String tableAlias = "temp";
        Select select = getSelectStr(database, table, fields, fieldsAliases, where, groupBy, tableAlias);
        // Print the generated SQL statement
        System.out.println(select);
    }

    /**
     * Builds {@code SELECT <fields> FROM <database>.<table> AS <alias> [WHERE ...] [GROUP BY ...]}.
     *
     * @param database      schema qualifier for the table; blank means an unqualified table
     * @param table         table name
     * @param fields        select-list expressions, each parsed with {@link CCJSqlParserUtil#parseExpression(String)}
     * @param fieldsAliases alias per field (same index); a null array, a missing index or an empty entry means "no alias"
     * @param where         WHERE condition; blank means no WHERE clause
     * @param groupBy       a single GROUP BY expression; blank means no GROUP BY clause
     * @param alias         table alias; blank means no alias
     * @return the assembled statement; its {@code toString()} is the SQL text
     * @throws JSQLParserException if a field, the WHERE condition or the GROUP BY expression cannot be parsed
     */
    private static Select getSelectStr(String database, String table, String[] fields, String[] fieldsAliases,
                                       String where, String groupBy, String alias) throws JSQLParserException {
        // PlainSelect carries the individual clauses; Select is the statement wrapper around it.
        PlainSelect plainSelect = new PlainSelect();

        // FROM clause. A MySQL "db.table" qualifier is JSqlParser's *schema* part: in its Table model the
        // database is the part before the schema (database.schema.table), so setting only the database may
        // render as "test..user".
        Table t = StringUtils.isNotBlank(database) ? new Table(database, table) : new Table(table);
        if (StringUtils.isNotBlank(alias)) {
            t.setAlias(new Alias(alias));
        }
        plainSelect.setFromItem(t);

        // SELECT list: parse each field so function calls become real expressions, not quoted identifiers.
        for (int i = 0; i < fields.length; i++) {
            SelectExpressionItem item = new SelectExpressionItem();
            item.setExpression(CCJSqlParserUtil.parseExpression(fields[i]));
            // Tolerate an alias array that is null or shorter than the field array.
            String fieldAlias = (fieldsAliases != null && i < fieldsAliases.length) ? fieldsAliases[i] : null;
            if (StringUtils.isNotEmpty(fieldAlias)) {
                item.setAlias(new Alias(fieldAlias));
            }
            plainSelect.addSelectItems(item);
        }

        // WHERE clause (optional)
        if (StringUtils.isNotBlank(where)) {
            plainSelect.setWhere(CCJSqlParserUtil.parseCondExpression(where));
        }

        // GROUP BY clause (optional)
        if (StringUtils.isNotBlank(groupBy)) {
            plainSelect.addGroupByColumnReference(CCJSqlParserUtil.parseExpression(groupBy));
        }

        Select select = new Select();
        select.setSelectBody(plainSelect);
        return select;
    }
}
执行结果如下:
SELECT id, name, age, count(id) AS c, avg(age) AS a, sum(a) AS s,
calculate_sum(concat_ws(',', collect_list(id))) AS t
FROM test.user AS temp WHERE age > 20 GROUP BY name
第二个例子:
把多条单表 SQL 作为子查询(别名依次为 A、B、C……),按计算表达式组合成一条关联 SQL,也就是 JOIN ... ON:
package jdbc;
import net.sf.jsqlparser.JSQLParserException;
import net.sf.jsqlparser.expression.Alias;
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.expression.operators.relational.EqualsTo;
import net.sf.jsqlparser.expression.operators.relational.ExpressionList;
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
import net.sf.jsqlparser.schema.Column;
import net.sf.jsqlparser.statement.select.*;
import org.apache.commons.lang3.StringUtils;
import org.springframework.util.CollectionUtils;
import java.util.*;
import java.util.stream.Collectors;
/**
 * Builds a composite-metric query by joining several derived-metric queries.
 *
 * <p>Each input SELECT becomes a derived table aliased {@code A}, {@code B}, {@code C}, ... in list order.
 * The first one is the FROM item, and every later one is joined to its predecessor. The outer select list
 * holds the plain columns common to all inputs (aliased {@code column1}, {@code column2}, ...). These are
 * followed by the user expression, rewritten to reference the derived tables and aliased {@code result}.
 */
public class c1 {

    /** Aliases are single upper-case letters, so at most 26 statements can be combined. */
    private static final int MAX_SUB_QUERIES = 26;

    /** Splits an expression before and after every arithmetic operator or parenthesis. */
    private static final java.util.regex.Pattern TOKEN_SPLIT =
            java.util.regex.Pattern.compile("(?=[\\+\\-\\*\\/\\(\\)])|(?<=[\\+\\-\\*\\/\\(\\)])");

    /** A bare identifier such as {@code t1} or {@code total_amount} (underscores allowed). */
    private static final java.util.regex.Pattern IDENTIFIER =
            java.util.regex.Pattern.compile("[A-Za-z_][A-Za-z0-9_]*");

    public static void main(String[] args) throws JSQLParserException {
        // Alternative input: five aggregated queries joined on "username".
        // String express = "total_amount * average_amount / user_count + XXX1 - TTT2";
        // List<String> sqlList = new ArrayList<>();
        // // SQL statement A
        // String sqlA = "SELECT username, SUM(total_amount) as total_amount FROM hadoop_ind.user_amount GROUP BY username";
        // // SQL statement B
        // String sqlB = "SELECT receiver_name as username, avg(total_amount) as average_amount FROM hadoop_ind.`order` GROUP BY receiver_name";
        // // SQL statement C
        // String sqlC = "SELECT username, count(*) as user_count FROM hadoop_ind.`user` GROUP BY username";
        // // SQL statement D
        // String sqlD = "SELECT username, AVG(*) as XXX1 FROM hadoop_ind.`XXX` GROUP BY username";
        // // SQL statement E
        // String sqlE = "SELECT username, MAX(*) as TTT2 FROM hadoop_ind.`TTT` GROUP BY username";
        // sqlList.add(sqlA);
        // sqlList.add(sqlB);
        // sqlList.add(sqlC);
        // sqlList.add(sqlD);
        // sqlList.add(sqlE);
        List<String> sqlList = new ArrayList<>();
        String express = " (t1 + t2) *10";
        String sqlA = "select min(a)+max(b) as t1 from db1.table1";
        String sqlB = "select min(c)+max(d) as t2 from db2.table2";
        sqlList.add(sqlA);
        sqlList.add(sqlB);
        // Output
        String result = combineSql(express, sqlList);
        System.out.println(result);
    }

    /**
     * Assembles the composite-metric SQL from derived-metric SQL statements.
     *
     * @param express arithmetic expression over the output column names of the statements, e.g. {@code "(t1 + t2) * 10"}
     * @param sqlList derived-metric SQL statements; each must parse to a plain SELECT
     * @return the combined SQL text
     * @throws JSQLParserException      if a statement or the rewritten expression cannot be parsed
     * @throws NullPointerException     if {@code express} is null
     * @throws IllegalArgumentException if {@code sqlList} is null, empty, or has more than 26 entries
     */
    public static String combineSql(String express, List<String> sqlList) throws JSQLParserException {
        Objects.requireNonNull(express, "express");
        if (sqlList == null || sqlList.isEmpty()) {
            throw new IllegalArgumentException("sqlList must contain at least one SQL statement");
        }
        if (sqlList.size() > MAX_SUB_QUERIES) {
            throw new IllegalArgumentException(
                    "At most " + MAX_SUB_QUERIES + " SQL statements can be combined, got " + sqlList.size());
        }
        // Alias each statement with a letter. Insertion order is kept so that alias 'A' + i is statement i,
        // which getJoins/getOnExpression rely on.
        LinkedHashMap<String, String> aliasMap = new LinkedHashMap<>();
        for (int i = 0; i < sqlList.size(); i++) {
            aliasMap.put(String.valueOf((char) ('A' + i)), sqlList.get(i));
        }
        // Parse every statement into its clauses (select items, FROM, WHERE, GROUP BY, ...).
        Map<String, PlainSelect> plainSelectMap = getPlainSelectMap(aliasMap);
        // Wrap each statement as an aliased derived table.
        List<SubSelect> subSelectList = getSubSelects(aliasMap, plainSelectMap);
        // Outer query
        PlainSelect parentSelect = new PlainSelect();
        List<SelectItem> selectItems = new ArrayList<>();
        // Plain (non-function) columns shared by all statements.
        setColumnName(plainSelectMap, selectItems);
        // The user expression (currently a single computed expression).
        setFunctionName(express, plainSelectMap, selectItems, parentSelect);
        // The first statement is the FROM item; the others are joined to it in order.
        parentSelect.setFromItem(subSelectList.get(0));
        parentSelect.setJoins(getJoins(plainSelectMap, subSelectList));

        Select finalSelect = new Select();
        finalSelect.setSelectBody(parentSelect);
        return finalSelect.toString();
    }

    /**
     * Builds one JOIN per derived table after the first, each joined to its predecessor.
     * A join gets no ON clause when {@link #getOnExpression} returns null.
     */
    private static List<Join> getJoins(Map<String, PlainSelect> plainSelectMap, List<SubSelect> subSelectList) {
        List<Join> joins = new ArrayList<>();
        for (int i = 1; i < subSelectList.size(); i++) {
            Join join = new Join();
            join.setRightItem(subSelectList.get(i));
            final Expression onExpression = getOnExpression(plainSelectMap, i);
            if (Objects.nonNull(onExpression)) {
                join.addOnExpression(onExpression);
            }
            joins.add(join);
        }
        return joins;
    }

    /**
     * Wraps every parsed statement as a derived table aliased with its key.
     * Keys are visited in sorted order (A, B, C, ...) so list index i corresponds to alias 'A' + i
     * regardless of the map implementation passed in.
     */
    private static List<SubSelect> getSubSelects(Map<String, String> aliasMap, Map<String, PlainSelect> plainSelectMap) {
        List<SubSelect> subSelectList = new ArrayList<>();
        for (String key : new TreeSet<>(aliasMap.keySet())) {
            SubSelect subSelect = new SubSelect();
            subSelect.setAlias(new Alias(key));
            subSelect.setSelectBody(plainSelectMap.get(key));
            subSelectList.add(subSelect);
        }
        return subSelectList;
    }

    /**
     * Parses every statement into a {@link PlainSelect}, keyed by alias.
     * The result keeps the iteration order of {@code aliasMap}.
     *
     * @throws JSQLParserException if a statement cannot be parsed
     * @throws ClassCastException  if a statement is not a plain SELECT (e.g. a UNION or a non-SELECT statement)
     */
    private static Map<String, PlainSelect> getPlainSelectMap(Map<String, String> aliasMap) throws JSQLParserException {
        Map<String, PlainSelect> plainSelectMap = new LinkedHashMap<>();
        for (Map.Entry<String, String> entry : aliasMap.entrySet()) {
            Select select = (Select) CCJSqlParserUtil.parse(entry.getValue());
            plainSelectMap.put(entry.getKey(), (PlainSelect) select.getSelectBody());
        }
        return plainSelectMap;
    }

    /**
     * Collects the plain (non-function) output columns of every statement as {@code "<alias>.<name>"}.
     * The name is the column alias when present, otherwise the bare column name (without any table qualifier).
     *
     * @param plainSelectMap parsed statements keyed by alias
     * @return qualified column names, e.g. {@code [A.username, B.username]}
     */
    public static List<String> getOnColumnName(Map<String, PlainSelect> plainSelectMap) {
        List<String> columnList = new ArrayList<>();
        for (Map.Entry<String, PlainSelect> entry : plainSelectMap.entrySet()) {
            for (SelectItem selectItem : entry.getValue().getSelectItems()) {
                String name = getPlainColumnName(selectItem);
                if (name != null) {
                    columnList.add(entry.getKey() + "." + name);
                }
            }
        }
        return columnList;
    }

    /**
     * Adds the plain columns present in every statement to the outer select list,
     * aliased {@code column1}, {@code column2}, ...
     *
     * @param plainSelectMap parsed statements keyed by alias
     * @param selectItems    outer select list to append to
     */
    private static void setColumnName(Map<String, PlainSelect> plainSelectMap, List<SelectItem> selectItems) {
        List<String> columnList = getOnColumnName(plainSelectMap);
        if (CollectionUtils.isEmpty(columnList)) {
            return;
        }
        int i = 1;
        // Keep only names shared by all statements, qualified with the first statement's alias.
        for (String column : compareGroups(columnList)) {
            SelectExpressionItem item = new SelectExpressionItem();
            item.setExpression(new Column(column));
            item.setAlias(new Alias(String.format("column%d", i++)));
            selectItems.add(item);
        }
    }

    /**
     * Qualifies every identifier of {@code express} with the alias of the statement that outputs it.
     * For example, {@code "total_amount-average_amount+user_count"} becomes
     * {@code "A.total_amount-B.average_amount+C.user_count"}.
     *
     * <p>An identifier is matched exactly against each statement's output names (column alias, or bare column
     * name for plain columns). The first statement in map order wins. Identifiers that match no statement,
     * e.g. function names, are kept unqualified. Whitespace around tokens is dropped.
     *
     * @param express        user expression
     * @param plainSelectMap parsed statements keyed by alias
     * @return the rewritten expression text
     */
    private static String getFunctionName(String express, Map<String, PlainSelect> plainSelectMap) {
        // Output column name -> alias of the first statement that produces it.
        Map<String, String> ownerByColumn = new HashMap<>();
        for (Map.Entry<String, PlainSelect> entry : plainSelectMap.entrySet()) {
            for (SelectItem selectItem : entry.getValue().getSelectItems()) {
                String name = getOutputName(selectItem);
                if (name != null) {
                    ownerByColumn.putIfAbsent(name, entry.getKey());
                }
            }
        }
        StringBuilder sb = new StringBuilder();
        for (String token : TOKEN_SPLIT.split(express, -1)) {
            String part = token.trim();
            String owner = IDENTIFIER.matcher(part).matches() ? ownerByColumn.get(part) : null;
            if (owner != null) {
                sb.append(owner).append('.');
            }
            // Operators, parentheses, literals and unknown identifiers are appended as-is.
            sb.append(part);
        }
        return sb.toString();
    }

    /**
     * Appends the rewritten user expression, aliased {@code result}, to the select list.
     * It then installs the select list on the outer query.
     *
     * @throws JSQLParserException if the rewritten expression cannot be parsed
     */
    private static void setFunctionName(
            String express,
            Map<String, PlainSelect> plainSelectMap,
            List<SelectItem> selectItems,
            PlainSelect parentSelect
    ) throws JSQLParserException {
        String finalResult = getFunctionName(express, plainSelectMap);
        SelectExpressionItem item = new SelectExpressionItem();
        item.setExpression(CCJSqlParserUtil.parseExpression(finalResult));
        item.setAlias(new Alias("result"));
        selectItems.add(item);
        parentSelect.setSelectItems(selectItems);
    }

    /**
     * Keeps the column names present in every group, qualified with the first group's alias.
     * For example, {@code [A.username, A.id, B.username, B.id, C.username, C.id]} gives
     * {@code [A.username, A.id]}. If only A has {@code id}, the result is {@code [A.username]}.
     * Order follows the first group's column order.
     *
     * @param dataList qualified names {@code "<alias>.<name>"}; entries with more than one dot are ignored
     * @return the common names, or an empty list when no usable entry exists
     */
    private static List<String> compareGroups(List<String> dataList) {
        Map<String, List<String>> categorizedData = dataList.stream()
                .map(item -> item.split("\\."))
                .filter(parts -> parts.length == 2)
                .collect(Collectors.groupingBy(parts -> parts[0], LinkedHashMap::new,
                        Collectors.mapping(parts -> parts[1], Collectors.toList())));
        if (categorizedData.isEmpty()) {
            return new ArrayList<>();
        }
        Map.Entry<String, List<String>> firstGroup = categorizedData.entrySet().iterator().next();
        // LinkedHashSet: de-duplicates while keeping a deterministic column order.
        Set<String> commonValues = new LinkedHashSet<>();
        for (String value : firstGroup.getValue()) {
            boolean isCommon = categorizedData.values().stream().allMatch(list -> list.contains(value));
            if (isCommon) {
                commonValues.add(firstGroup.getKey() + "." + value);
            }
        }
        return new ArrayList<>(commonValues);
    }

    /**
     * Builds the ON condition joining the derived table at {@code index} to the one before it.
     *
     * <p>If both statements output a plain column, its name must be identical on both sides and is used as
     * the key. Otherwise, the GROUP BY columns of both statements are paired by position and combined with AND.
     *
     * <p>NOTE(review): GROUP BY columns are only visible on the derived table if the statement also selects
     * them. Otherwise {@code A.col} may not resolve in the generated SQL.
     *
     * @param plainSelectMap parsed statements keyed by alias
     * @param index          index (>= 1) of the right-hand derived table
     * @return the ON condition, or null when neither a common column nor GROUP BY columns are available
     * @throws IllegalArgumentException if there is no statement for alias {@code 'A' + index} or its predecessor
     * @throws IllegalStateException    if the plain key columns differ, the GROUP BY column counts differ,
     *                                  or a GROUP BY entry is not a plain column
     */
    public static Expression getOnExpression(Map<String, PlainSelect> plainSelectMap, int index) {
        final String key1 = String.valueOf((char) ('A' + index - 1));
        final String key2 = String.valueOf((char) ('A' + index));
        PlainSelect plainSelect1 = plainSelectMap.get(key1);
        PlainSelect plainSelect2 = plainSelectMap.get(key2);
        if (plainSelect1 == null || plainSelect2 == null) {
            throw new IllegalArgumentException("No SQL statement for alias " + key1 + " or " + key2);
        }
        String columnName1 = getOnColumnName(plainSelect1.getSelectItems());
        String columnName2 = getOnColumnName(plainSelect2.getSelectItems());
        // Either both statements expose the same leading plain column, or neither exposes one.
        if (!Objects.equals(columnName1, columnName2)) {
            throw new IllegalStateException("No common column name found: "
                    + key1 + "." + columnName1 + " vs " + key2 + "." + columnName2);
        }
        List<String> leftColumns;
        List<String> rightColumns;
        if (StringUtils.isNotEmpty(columnName1)) {
            leftColumns = Collections.singletonList(columnName1);
            rightColumns = Collections.singletonList(columnName2);
        } else {
            // No plain column: fall back to the GROUP BY columns.
            leftColumns = getOnGroupColumnNames(plainSelect1);
            rightColumns = getOnGroupColumnNames(plainSelect2);
        }
        // Without GROUP BY columns on either side there is no ON condition.
        if (leftColumns.isEmpty() || rightColumns.isEmpty()) {
            return null;
        }
        if (leftColumns.size() != rightColumns.size()) {
            throw new IllegalStateException("GROUP BY columns do not match: "
                    + key1 + leftColumns + " vs " + key2 + rightColumns);
        }
        Expression on = null;
        for (int i = 0; i < leftColumns.size(); i++) {
            EqualsTo equalsTo = new EqualsTo(
                    new Column(key1 + "." + leftColumns.get(i)),
                    new Column(key2 + "." + rightColumns.get(i)));
            if (on == null) {
                on = equalsTo;
            } else {
                on = new net.sf.jsqlparser.expression.operators.conditional.AndExpression(on, equalsTo);
            }
        }
        return on;
    }

    /**
     * Returns the output name of the first plain-column select item (its alias, else its column name),
     * or null if the list has no plain column.
     */
    private static String getOnColumnName(List<SelectItem> selectItems) {
        for (SelectItem selectItem : selectItems) {
            String name = getPlainColumnName(selectItem);
            if (name != null) {
                return name;
            }
        }
        return null;
    }

    /**
     * Returns the GROUP BY column names of a statement in declaration order.
     * The list is empty when there is no GROUP BY.
     *
     * @throws IllegalStateException if a GROUP BY entry is not a plain column (it cannot be used as a join key)
     */
    private static List<String> getOnGroupColumnNames(PlainSelect plainSelect) {
        GroupByElement groupByElement = plainSelect.getGroupBy();
        if (Objects.isNull(groupByElement) || Objects.isNull(groupByElement.getGroupByExpressionList())) {
            return Collections.emptyList();
        }
        final ExpressionList groupByExpressionList = groupByElement.getGroupByExpressionList();
        final List<Expression> expressions = groupByExpressionList.getExpressions();
        if (expressions == null || expressions.isEmpty()) {
            return Collections.emptyList();
        }
        List<String> names = new ArrayList<>(expressions.size());
        for (Expression expression : expressions) {
            if (!(expression instanceof Column)) {
                throw new IllegalStateException("Unsupported GROUP BY expression for a join key: " + expression);
            }
            names.add(((Column) expression).getColumnName());
        }
        return names;
    }

    /**
     * Output name of a plain-column select item: the alias if present, else the bare column name.
     * Returns null for anything that is not a plain column.
     */
    private static String getPlainColumnName(SelectItem selectItem) {
        if (!(selectItem instanceof SelectExpressionItem)) {
            return null;
        }
        SelectExpressionItem item = (SelectExpressionItem) selectItem;
        if (!(item.getExpression() instanceof Column)) {
            return null;
        }
        Alias alias = item.getAlias();
        return alias == null ? ((Column) item.getExpression()).getColumnName() : alias.getName();
    }

    /**
     * Output name of any select item: its alias if present, else the bare column name for a plain column.
     * Returns null for unaliased non-column items (e.g. {@code count(*)}) and for {@code *}.
     */
    private static String getOutputName(SelectItem selectItem) {
        if (selectItem instanceof SelectExpressionItem) {
            Alias alias = ((SelectExpressionItem) selectItem).getAlias();
            if (alias != null) {
                return alias.getName();
            }
        }
        return getPlainColumnName(selectItem);
    }
}
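执行结果(注:下面的结果对应 sqlA、sqlB 末尾都加上 `group by name` 的情况;按上面的 main 原样运行,子查询里不会有 GROUP BY,也不会生成 ON 条件。另外,子查询的 select 中没有输出 name 列,A.name 在多数数据库里无法引用,实际使用时需要把 group by 列也加到子查询的 select 中):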
SELECT
(A.t1 + B.t2) * 10 AS result
FROM
(
SELECT
min(a) + max(b) AS t1
FROM
db1.table1
GROUP BY
name) AS A
JOIN (
SELECT
min(c) + max(d) AS t2
FROM
db2.table2
GROUP BY
name) AS B
ON A.name = B.name