背景
当前系统希望通过解析Flink SQL提取其中使用到的自定义函数,按实际需要只加载必要的资源jar:提交Flink SQL任务时,将PROVIDED_LIB_DIRS设置为所依赖jar对应的HDFS路径,从而避免加载多余的jar以及由此引发的依赖冲突。
目标
解析Flink SQL,提取其中使用到的自定义函数列表。
方案
以Flink SQL任务中最为常见的单条INSERT INTO ... SELECT ...语句为例:
INSERT INTO t1
SELECT fun0(func(c1)), count(distinct c2)
FROM (
SELECT (select fun1(c1, c3) from table1 where func2(c2) = 1),
fun3(c4)
from table2 where fun4(c6) = 1 and avg(c9) > 0
)
where fun5(c5) = 1 and max(c6) > 0 and fun6(c7) = 1 and c8 = (select fun7(id) from table2 where func8(id) > 0 ) GROUP BY c8;
1、将SQL文本拆分为单条语句
2、使用calcite解析单条SQL,生成SqlNode树
3、递归遍历SqlNode树,提取SqlBasicCall节点并获取其算子名称(OperatorName),其中函数调用对应的算子(SqlOperator)类型一般为SqlUnresolvedFunction
代码示例:
public class FinkUdfUtil {
public static void main(String[] args) throws Exception {
String sql =
"INSERT INTO t1 \n"
+ " SELECT fun0(func(c1)), count(distinct c2) \n"
+ " FROM ( \n"
+ " SELECT (select fun1(c1, c3) from table1 where func2(c2) = 1), \n"
+ " fun3(c4) \n"
+ " from table2 where fun4(c6) = 1 and avg(c9) > 0\n"
+ " ) \n"
+ " where fun5(c5) = 1 and max(c6) > 0 and fun6(c7) = 1 "
+ " and c8 = (select fun7(id) from table2 where func8(id) > 0 ) GROUP BY c8;";
System.out.println(sql);
// 创建解析器
List<String> result = extraFun(sql);
System.out.println(result);
}
private static List<String> extraFun(String sql) throws SqlParseException {
List<String> functions = Lists.newArrayList();
SqlParser parser = SqlParser.create(sql, SqlParser.config()
.withParserFactory(FlinkSqlParserImpl.FACTORY)
.withQuoting(Quoting.BACK_TICK)
.withUnquotedCasing(Casing.UNCHANGED)
.withQuotedCasing(Casing.UNCHANGED)
.withConformance(FlinkSqlConformance.DEFAULT)
);
List<SqlNode> sqlNodeList = parser.parseStmtList().getList();
if (CollectionUtils.isEmpty(sqlNodeList)) {
return functions;
}
for (SqlNode sqlNode : sqlNodeList) {
if (!(sqlNode instanceof RichSqlInsert)) {
continue;
}
SqlNode selSQL = ((RichSqlInsert) sqlNode).getSource();
if (selSQL instanceof SqlSelect) {
getSqlSelect((SqlSelect) selSQL, functions);
}
}
return functions;
}
@NotNull
private static void getSqlSelect(SqlSelect selSQL, List<String> functions) {
List<SqlNode> list = selSQL.getSelectList().getList();
for (SqlNode node : list) {
if (node instanceof SqlBasicCall) {
getSqlCall((SqlBasicCall) node, functions);
} else if (node instanceof SqlSelect) {
getSqlSelect((SqlSelect) node, functions);
}
}
SqlNode from = selSQL.getFrom();
if (from != null && from instanceof SqlSelect) {
getSqlSelect((SqlSelect) from, functions);
}
SqlBasicCall where = (SqlBasicCall) selSQL.getWhere();
if (where == null) {
return;
}
getSqlCall(where, functions);
}
private static void getSqlCall(SqlBasicCall sqlCall, List<String> functions) {
List<SqlNode> calls = sqlCall.getOperandList();
if(((SqlCall) sqlCall).getOperator() instanceof SqlUnresolvedFunction){
functions.add(((SqlCall) sqlCall).getOperator().getName());
}
for (SqlNode cNode : calls) {
if (cNode instanceof SqlBasicCall) {
getSqlCall((SqlBasicCall) cNode, functions);
} else if (cNode instanceof SqlSelect) {
getSqlSelect((SqlSelect) cNode, functions);
}
}
}
}