背景及介绍:
MyBatis支持动态SQL,这一特性使得我们在项目中使用SQL变得简单易用。众所周知,动态SQL依靠在XML文件中配置的各种标签(if、foreach、where、choose等)来实现,使用这些标签可以省去大量判断和拼接SQL的工作,让代码更加整洁和规范。那么,能否通过动态构建XML模板,来构建并解析动态SQL呢?接下来的这部分代码就实现了这一点。
代码实现:
1.引入依赖:
<!-- MyBatis: provides XPathParser, XNode, Configuration, LanguageDriver and the SqlNode classes used below.
     NOTE(review): the code in this article also references SQLUtils / DbType / MySqlSelectQueryBlock
     (presumably Alibaba Druid), Base64Utils / @Component (Spring), Lists (Guava), CollectionUtils,
     @Data (Lombok), @ApiModelProperty (Swagger) and JSON (presumably fastjson); those dependencies
     are not declared here - TODO confirm coordinates and versions. -->
<dependency>
<groupId>org.mybatis</groupId>
<artifactId>mybatis</artifactId>
<version>3.5.9</version>
</dependency>
2.创建SQLSource:
初始化并创建一个动态SqlSource。MyBatis中常见的SqlSource有两种:`DynamicSqlSource`和`RawSqlSource`。在`createSqlSource()`方法中,通过`xPathParser`解析器解析出标签节点列表`xNodes`,取其中唯一的`<select>`根节点,构建出`sqlSource`。随后通过`parseSQLNode()`方法判断构建出的`sqlSource`是`DynamicSqlSource`还是`RawSqlSource`:若为`DynamicSqlSource`,则通过反射取出其根节点下的SQL语法树节点列表`List<SqlNode>`。
/**
 * Builds a MyBatis {@link SqlSource} from an XML document containing a single {@code <select>}
 * element and exposes the statement's top-level {@link SqlNode}s for parameter extraction.
 *
 * <p>Usage: construct, call {@link #init()}, then {@link #parseSQLNode()}.
 */
public class SQLSourceCreator {
    protected Logger Logger = LoggerFactory.getLogger(this.getClass());
    protected XPathParser xPathParser;
    protected Configuration configuration;
    /** XPath evaluated against the document; exactly one match is required. */
    private static final String EVAL_NODE = "select";
    /** Private field of DynamicSqlSource holding the root node (read via reflection). */
    private static final String ROOT_SQL_NODE = "rootSqlNode";
    /** Private field of the root node holding its child list (read via reflection). */
    private static final String CONTENTS = "contents";
    protected SqlSource sqlSource;

    private SQLSourceCreator() {
    }

    /**
     * @param xPathParser   parser over the XML document containing the statement
     * @param configuration MyBatis configuration to use; a new default {@link Configuration}
     *                      is created when {@code null}
     */
    public SQLSourceCreator(XPathParser xPathParser, Configuration configuration) {
        this.xPathParser = xPathParser;
        // Honour a caller-supplied configuration; leaving the field null here would make
        // createSqlSource() fail with a NullPointerException.
        this.configuration = configuration != null ? configuration : new Configuration();
    }

    /** Parses the {@code select} node and builds {@link #sqlSource}. */
    public void init() {
        createSqlSource();
    }

    /**
     * Evaluates {@link #EVAL_NODE} and asks the configuration's default language driver to
     * build the SqlSource from the single matching node.
     *
     * @throws ApiException DYNAMIC_SQL_PARSE_ERROR when the document does not contain exactly
     *                      one {@code select} node
     */
    private void createSqlSource() {
        List<XNode> xNodes = xPathParser.evalNodes(EVAL_NODE);
        if (xNodes == null || xNodes.size() != 1) {
            Logger.warn("动态SQL解析异常");
            throw new ApiException(ErrorCode.Business.DYNAMIC_SQL_PARSE_ERROR);
        }
        // getLanguageDriver(null) resolves to the configuration's default driver
        // (presumably XMLLanguageDriver unless overridden - TODO confirm).
        LanguageDriver languageDriver = configuration.getLanguageDriver(null);
        this.sqlSource = languageDriver.createSqlSource(configuration, xNodes.get(0), null);
    }

    /**
     * Returns the top-level SQL nodes of the statement.
     *
     * @return the child nodes of the DynamicSqlSource's root node; an empty list when the
     *         source is not a DynamicSqlSource or has no root/child list, so callers can
     *         iterate without a null check
     */
    @SuppressWarnings("unchecked") // "contents" is declared as List<SqlNode> on the root node
    public List<SqlNode> parseSQLNode() {
        if (!(sqlSource instanceof DynamicSqlSource)) {
            return java.util.Collections.emptyList();
        }
        SqlNode rootSqlNode = (SqlNode) ReflectUtil.reflectPrivateFiled(ROOT_SQL_NODE, sqlSource);
        if (rootSqlNode == null) {
            return java.util.Collections.emptyList();
        }
        List<SqlNode> contents = (List<SqlNode>) ReflectUtil.reflectPrivateFiled(CONTENTS, rootSqlNode);
        return contents != null ? contents : java.util.Collections.emptyList();
    }

    /** Collects dynamic request parameter names from {@code sqlNodeList}. */
    public Set<String> parseDynamicRequestParam(List<SqlNode> sqlNodeList) {
        DynamicSQLParser dynamicSQLParser = new DynamicSQLParser();
        return dynamicSQLParser.parseDynamic(sqlNodeList);
    }

    /** Returns the first static SELECT text in {@code sqlNodeList}, or {@code null}. */
    public String parseStaticReturnParam(List<SqlNode> sqlNodeList) {
        DynamicSQLParser dynamicSQLParser = new DynamicSQLParser();
        return dynamicSQLParser.parseStatic(sqlNodeList);
    }
}
3.创建驱动适配器:
创建SQL解析驱动适配器抽象类`AbstractDriverAdapter`,为不同数据库驱动的适配器提供统一的抽象接口。其中`parse()`是抽象方法,需要由具体驱动各自实现;`staticReturnParam()`用于解析SQL的返回列,目前是公共实现,也可以根据需求调整。
/**
 * Base class for database-driver-specific SQL script parsers.
 *
 * @param <T> result type produced by {@link #parse(String)}
 */
public abstract class AbstractDriverAdapter<T> {

    /**
     * Parses an API script into a driver-specific result.
     *
     * @param apiScript the script to parse (encoding is defined by the implementation)
     * @return the parse result
     */
    public abstract T parse(String apiScript);

    /**
     * 解析sql返回值参数列表 - extracts the column labels returned by every SELECT statement in
     * {@code sql}.
     *
     * <p>For each select item the alias is used when present (e.g. {@code year_month as yearMonth}
     * yields {@code yearMonth}); otherwise the item's SQL text is used. Backticks are removed.
     *
     * @param sql        plain SQL text (no MyBatis tags)
     * @param driverType dialect handed to the SQL parser
     * @return the distinct column labels; empty when {@code sql} contains no SELECT statement
     * @throws ApiException RETURN_COLUMN_PARSE_ERROR, wrapping FULL_COLUMN_QUERY_ERROR when a
     *                      select item is {@code *} / {@code t.*}, or raised when a SELECT is
     *                      not a single MySQL query block (e.g. a UNION)
     */
    public Set<String> staticReturnParam(String sql, DbType driverType) {
        List<SQLStatement> sqlStatements = SQLUtils.parseStatements(sql, driverType);
        Set<String> returnParams = new HashSet<>();
        try {
            for (SQLStatement sqlStatement : sqlStatements) {
                if (!(sqlStatement instanceof SQLSelectStatement)) {
                    continue;
                }
                Object query = ((SQLSelectStatement) sqlStatement).getSelect().getQuery();
                if (!(query instanceof MySqlSelectQueryBlock)) {
                    // e.g. UNION queries; report a parse error instead of a ClassCastException.
                    throw new ApiException(ErrorCode.Business.RETURN_COLUMN_PARSE_ERROR);
                }
                List<SQLSelectItem> selectList = ((MySqlSelectQueryBlock) query).getSelectList();
                for (SQLSelectItem selectItem : selectList) {
                    returnParams.add(resolveColumnName(selectItem));
                }
            }
        } catch (ApiException aex) {
            throw new ApiException(aex, ErrorCode.Business.RETURN_COLUMN_PARSE_ERROR);
        }
        return returnParams;
    }

    /**
     * Returns the label for one select item: its alias, or its SQL text when it has none.
     *
     * @throws ApiException FULL_COLUMN_QUERY_ERROR for {@code *} or {@code table.*}
     */
    private static String resolveColumnName(SQLSelectItem selectItem) {
        // getAlias() is null for items written without an alias, so it must be checked before
        // it is dereferenced; the wildcard check must run on the resolved text, not the alias.
        String alias = selectItem.getAlias();
        String column = alias == null ? null : alias.replaceAll("`", "");
        if (column == null || StringUtil.checkValNull(column)) {
            column = selectItem.toString().replaceAll("`", "");
        }
        String trimmed = column.trim();
        if (trimmed.equals("*") || trimmed.endsWith(".*")) {
            throw new ApiException(ErrorCode.Business.FULL_COLUMN_QUERY_ERROR);
        }
        return column;
    }
}
4.实现MysqlDriverAdapter:
`MysqlDriverAdapter`继承了抽象类`AbstractDriverAdapter`,并实现了其中的`parse()`方法。这里要重点关注`create()`方法:它先将传入的脚本做Base64解码,用`<select>`标签包裹后构建XML解析器`XPathParser`,再把`xPathParser`传给`SQLSourceCreator`完成初始化。`parse()`在构建出`SQLSourceCreator`后,通过其`parseSQLNode()`获取`sqlNodeList`,再分别交给`parseReqParams()`解析请求参数、交给`parseReturnParams()`解析返回参数。
/**
 * MySQL implementation of {@link AbstractDriverAdapter}: decodes a Base64 MyBatis script,
 * wraps it in {@code <select>} and extracts its dynamic request parameters and static return
 * columns.
 *
 * <p>As a Spring {@code @Component} (singleton scope by default) this bean can be invoked
 * concurrently, so {@link #parse(String)} keeps all per-call state in local variables rather
 * than in the shared {@code sqlSourceCreator} field.
 */
@Component
public class MysqlDriverAdapter extends AbstractDriverAdapter<ParseParamsDTO> {
    /** Wraps the decoded script so MyBatis parses it as a single select statement. */
    private static final String XML = "<select>%s</select>";
    /** Creator built by the most recent {@link #create(String)} call. */
    private SQLSourceCreator sqlSourceCreator;

    /**
     * Decodes {@code apiScript} and stores the initialised creator for use by
     * {@link #parseReqParams(List)} and {@link #parseReturnParams(List)}.
     *
     * @param apiScript Base64-encoded script body (without the enclosing select tag)
     */
    public void create(String apiScript) {
        this.sqlSourceCreator = buildCreator(apiScript);
    }

    /**
     * Decodes and parses {@code apiScript} into its request parameters and return columns.
     *
     * @param apiScript Base64-encoded script body (without the enclosing select tag)
     * @return request parameters (from dynamic nodes) and return columns (from the static
     *         SELECT text)
     */
    @Override
    public ParseParamsDTO parse(String apiScript) {
        SQLSourceCreator creator = buildCreator(apiScript);
        List<SqlNode> sqlNodeList = creator.parseSQLNode();
        if (sqlNodeList == null) {
            // No dynamic node list (presumably a script without MyBatis tags); treat as empty
            // instead of failing with a NullPointerException downstream.
            sqlNodeList = java.util.Collections.emptyList();
        }
        ParseParamsDTO parseParamsDTO = new ParseParamsDTO();
        // Dynamic request parameters
        parseParamsDTO.setRequestParams(buildRequestParams(creator, sqlNodeList));
        // Static return parameters
        parseParamsDTO.setReturnParams(buildReturnParams(creator, sqlNodeList));
        return parseParamsDTO;
    }

    /**
     * 解析请求参数 - request parameters, using the creator stored by {@link #create(String)}.
     */
    public List<ParseParamsDTO.RequestParams> parseReqParams(List<SqlNode> sqlNodeList) {
        return buildRequestParams(sqlSourceCreator, sqlNodeList);
    }

    /**
     * 解析返回参数 - return columns, using the creator stored by {@link #create(String)}.
     */
    public List<ParseParamsDTO.ReturnParams> parseReturnParams(List<SqlNode> sqlNodeList) {
        return buildReturnParams(sqlSourceCreator, sqlNodeList);
    }

    /** Decodes the Base64 script (UTF-8), wraps it in the select template and initialises a creator. */
    private static SQLSourceCreator buildCreator(String apiScript) {
        // Explicit UTF-8: the platform default charset could corrupt non-ASCII SQL text.
        String apiScriptDecode = new String(Base64Utils.decodeFromString(apiScript),
                java.nio.charset.StandardCharsets.UTF_8);
        String xmlFormat = String.format(XML, apiScriptDecode);
        // NOTE(review): if apiScript comes from untrusted clients, confirm the XML parser used by
        // XPathParser has external entities / DTDs disabled.
        XPathParser xPathParser = new XPathParser(xmlFormat);
        SQLSourceCreator creator = new SQLSourceCreator(xPathParser, null);
        creator.init();
        return creator;
    }

    /** Maps each dynamic parameter name to a RequestParams entry with only columnName set. */
    private List<ParseParamsDTO.RequestParams> buildRequestParams(SQLSourceCreator creator,
                                                                  List<SqlNode> sqlNodeList) {
        Set<String> reqParams = creator.parseDynamicRequestParam(sqlNodeList);
        List<ParseParamsDTO.RequestParams> requestList = Lists.newArrayList();
        if (CollectionUtils.isNotEmpty(reqParams)) {
            for (String param : reqParams) {
                ParseParamsDTO.RequestParams requestP = ParseParamsDTO.buildReqParams();
                requestP.setColumnName(param);
                requestList.add(requestP);
            }
        }
        return requestList;
    }

    /** Parses the static SELECT text (if any) as MySQL and maps each column to a ReturnParams entry. */
    private List<ParseParamsDTO.ReturnParams> buildReturnParams(SQLSourceCreator creator,
                                                                List<SqlNode> sqlNodeList) {
        String staticSql = creator.parseStaticReturnParam(sqlNodeList);
        List<ParseParamsDTO.ReturnParams> returnList = Lists.newArrayList();
        if (staticSql != null) {
            Set<String> returnParams = super.staticReturnParam(staticSql, DbType.mysql);
            for (String rp : returnParams) {
                ParseParamsDTO.ReturnParams returnP = ParseParamsDTO.buildReturnParams();
                returnP.setColumnName(rp);
                returnList.add(returnP);
            }
        }
        return returnList;
    }
}
5.创建解析器抽象类:
/**
 * Walks a MyBatis {@link SqlNode} tree to collect dynamic request parameters and to locate the
 * static SELECT text.
 */
public abstract class AbstractSQLParser {

    /**
     * 解析动态请求参数 - collects request parameter names from dynamic SQL nodes.
     *
     * <p>{@code IfSqlNode}, {@code ForEachSqlNode} and {@code TextSqlNode} are handled by their
     * dedicated parsers. Any other non-static node (e.g. the container nodes MyBatis builds for
     * {@code <where>}, {@code <set>}, {@code <trim>}, {@code <choose>}, or a nested mixed node)
     * is descended into, so parameters wrapped in those tags are collected as well.
     *
     * @param sqlNodeList nodes to inspect; {@code null} is treated as empty
     * @return a new mutable set of the parameter names found
     */
    public Set<String> parseDynamic(List<SqlNode> sqlNodeList) {
        Set<String> requestParamList = new HashSet<>();
        if (sqlNodeList == null) {
            return requestParamList;
        }
        for (SqlNode sqlNode : sqlNodeList) {
            if (sqlNode instanceof IfSqlNode) {
                requestParamList.addAll(new IfSQLNodeParse().sqlNodeParse(sqlNode, requestParamList));
            } else if (sqlNode instanceof ForEachSqlNode) {
                requestParamList.addAll(new ForeachSQLNodeParse().sqlNodeParse(sqlNode, requestParamList));
            } else if (sqlNode instanceof TextSqlNode) {
                requestParamList.addAll(new TextSQLNodeParse().sqlNodeParse(sqlNode, requestParamList));
            } else if (sqlNode != null && !(sqlNode instanceof StaticTextSqlNode)) {
                // Container node without a dedicated parser: recurse into its children.
                requestParamList.addAll(parseDynamic(childSqlNodes(sqlNode)));
            }
        }
        return requestParamList;
    }

    /**
     * 解析静态返回参数 - returns the first static SELECT text among {@code sqlNodeList}.
     *
     * @param sqlNodeList nodes to inspect; {@code null} is treated as empty
     * @return the text of the first StaticTextSqlNode recognised as a SELECT, or {@code null}
     */
    public String parseStatic(List<SqlNode> sqlNodeList) {
        if (sqlNodeList == null) {
            return null;
        }
        for (SqlNode sqlNode : sqlNodeList) {
            if (!(sqlNode instanceof StaticTextSqlNode)) {
                continue;
            }
            StaticSQLNodeParse staticSQLNodeParse = new StaticSQLNodeParse();
            String staticSql = staticSQLNodeParse.sqlNodeParse(sqlNode, null);
            if (staticSql != null) {
                return staticSql;
            }
        }
        return null;
    }

    /**
     * Reflectively collects the direct child nodes of a container node: every non-static field
     * declared on the node's class or any superclass whose value is a SqlNode, or a List
     * containing SqlNodes. Superclasses matter because e.g. a where-node may keep its contents
     * in a field declared on its parent class.
     */
    private static List<SqlNode> childSqlNodes(SqlNode sqlNode) {
        List<SqlNode> children = new java.util.ArrayList<>();
        for (Class<?> type = sqlNode.getClass(); type != null && type != Object.class;
                type = type.getSuperclass()) {
            for (java.lang.reflect.Field field : type.getDeclaredFields()) {
                if (java.lang.reflect.Modifier.isStatic(field.getModifiers())) {
                    continue;
                }
                Object value;
                try {
                    field.setAccessible(true);
                    value = field.get(sqlNode);
                } catch (IllegalAccessException e) {
                    throw new IllegalStateException("Cannot read field '" + field.getName()
                            + "' of " + type.getName(), e);
                }
                if (value instanceof SqlNode) {
                    children.add((SqlNode) value);
                } else if (value instanceof List) {
                    for (Object element : (List<?>) value) {
                        if (element instanceof SqlNode) {
                            children.add((SqlNode) element);
                        }
                    }
                }
            }
        }
        return children;
    }
}
7.创建动态SQL解析器:
由于目前默认只解析MySQL动态SQL,直接继承抽象类的默认实现即可,无需重写任何方法。
/**
 * Default dynamic SQL parser (currently used for MySQL). All parsing logic is inherited
 * unchanged from {@link AbstractSQLParser}.
 */
public class DynamicSQLParser extends AbstractSQLParser {

    /** Creates a parser that uses the inherited default node handling. */
    public DynamicSQLParser() {
        super();
    }
}
8.实现不同XML节点解析:
/**
 * Strategy that extracts information from one kind of MyBatis {@link SqlNode}.
 *
 * @param <T> type of the value produced by the parser
 */
public abstract class SQLNodeParse<T> {

    /**
     * Parses a single node.
     *
     * @param sqlNode       the node to inspect
     * @param requestParams accumulator that parameter-collecting implementations add names to;
     *                      implementations that ignore it may receive {@code null}
     * @return the implementation-specific parse result
     */
    public abstract T sqlNodeParse(
            SqlNode sqlNode,
            Set<String> requestParams);
}
Mybatis中不同XML节点的具体解析实现:
IfSQLNodeParse:
/** Collects request parameters found in the body of an {@code <if>} node. */
public class IfSQLNodeParse extends SQLNodeParse<Set<String>> {
    private static final String CONTENTS = "contents";

    /**
     * Reads the IfSqlNode's private {@code contents} node and, when it is a MixedSqlNode,
     * gathers the parameters inside it via MixSQLNodeParse.
     *
     * @param sqlNode       an IfSqlNode
     * @param requestParams accumulator the names are added to
     * @return {@code requestParams}
     */
    @Override
    public Set<String> sqlNodeParse(SqlNode sqlNode, Set<String> requestParams) {
        SqlNode body = (SqlNode) ReflectUtil.reflectPrivateFiled(CONTENTS, sqlNode);
        if (!(body instanceof MixedSqlNode)) {
            return requestParams;
        }
        Set<String> nested = new MixSQLNodeParse().sqlNodeParse(body, requestParams);
        requestParams.addAll(nested);
        return requestParams;
    }
}
ForeachSQLNodeParse:
/** Records the collection expression of a {@code <foreach>} node as a request parameter. */
public class ForeachSQLNodeParse extends SQLNodeParse<Set<String>> {
    private static final String COLLECTION_EXPRESSION = "collectionExpression";

    /**
     * Adds the ForEachSqlNode's private {@code collectionExpression} (the {@code collection}
     * attribute) to {@code requestParams}; other node types are left untouched.
     *
     * @param sqlNode       node to inspect
     * @param requestParams accumulator the name is added to
     * @return {@code requestParams}
     */
    @Override
    public Set<String> sqlNodeParse(SqlNode sqlNode, Set<String> requestParams) {
        if (!(sqlNode instanceof ForEachSqlNode)) {
            return requestParams;
        }
        Object collection = ReflectUtil.reflectPrivateFiled(COLLECTION_EXPRESSION, sqlNode);
        requestParams.add((String) collection);
        return requestParams;
    }
}
MixSQLNodeParse:
/** Collects request parameters from the children of a MixedSqlNode. */
public class MixSQLNodeParse extends SQLNodeParse<Set<String>> {
    private static final String CONTENTS = "contents";

    /**
     * Reads the node's private {@code contents} list, parses it with a fresh
     * DynamicSQLParser and adds every name found to {@code requestParams}.
     *
     * @param sqlNode       a MixedSqlNode
     * @param requestParams accumulator the names are added to
     * @return {@code requestParams}
     */
    @Override
    @SuppressWarnings("unchecked") // "contents" is declared as List<SqlNode>
    public Set<String> sqlNodeParse(SqlNode sqlNode, Set<String> requestParams) {
        List<SqlNode> children = (List<SqlNode>) ReflectUtil.reflectPrivateFiled(CONTENTS, sqlNode);
        DynamicSQLParser parser = new DynamicSQLParser();
        Set<String> found = parser.parseDynamic(children);
        requestParams.addAll(found);
        return requestParams;
    }
}
TextSQLNodeParse:
/** Collects {@code ${name}} placeholder names from dynamic text nodes. */
public class TextSQLNodeParse extends SQLNodeParse<Set<String>> {
    private static final String TEXT = "text";
    /**
     * Matches {@code ${name}}, optionally with spaces inside the braces, and captures
     * {@code name} in group 1. The range is {@code A-Z}, not {@code A-z}: the latter also spans
     * the punctuation characters {@code [ \ ] ^ `} between 'Z' and 'a'.
     */
    private static final Pattern REGEX = Pattern.compile("\\$\\{\\s*([A-Za-z0-9_]+)\\s*\\}");

    /**
     * Adds every {@code ${name}} placeholder of a dynamic TextSqlNode to {@code requestParams}.
     * Non-text or non-dynamic nodes are ignored.
     *
     * @param sqlNode       node to inspect
     * @param requestParams accumulator the names are added to
     * @return {@code requestParams}
     */
    @Override
    public Set<String> sqlNodeParse(SqlNode sqlNode, Set<String> requestParams) {
        if (!(sqlNode instanceof TextSqlNode)) {
            return requestParams;
        }
        TextSqlNode textSqlNode = (TextSqlNode) sqlNode;
        if (!textSqlNode.isDynamic()) {
            return requestParams;
        }
        String text = (String) ReflectUtil.reflectPrivateFiled(TEXT, textSqlNode);
        Matcher matcher = REGEX.matcher(text);
        while (matcher.find()) {
            requestParams.add(matcher.group(1));
        }
        return requestParams;
    }
}
StaticSQLNodeParse:
/** Returns the text of a static SQL node when it is a SELECT statement. */
public class StaticSQLNodeParse extends SQLNodeParse<String> {
    private static final String TEXT = "text";
    private static final String SELECT = "select";

    /**
     * @param sqlNode       node to inspect
     * @param requestParams unused; callers may pass {@code null}
     * @return the node's raw (untrimmed) text when it is a StaticTextSqlNode whose trimmed text
     *         starts with "select" in any letter case; otherwise {@code null}
     */
    @Override
    public String sqlNodeParse(SqlNode sqlNode, Set<String> requestParams) {
        if (!(sqlNode instanceof StaticTextSqlNode)) {
            return null;
        }
        String staticSql = (String) ReflectUtil.reflectPrivateFiled(TEXT, sqlNode);
        if (staticSql == null) {
            return null;
        }
        // Case-insensitive so mixed-case keywords such as "Select" are recognised too.
        boolean isSelect = staticSql.trim().regionMatches(true, 0, SELECT, 0, SELECT.length());
        return isSelect ? staticSql : null;
    }
}
9.工具类及实体:
/** Reflection helper for reading private fields of classes that expose no accessor. */
public class ReflectUtil {
    protected static Logger Logger = LoggerFactory.getLogger(ReflectUtil.class);

    /**
     * Reads the value of a (typically private) instance field.
     *
     * <p>The field is looked up on the object's runtime class first and then on each
     * superclass, so fields inherited from a parent class can be read as well.
     *
     * @param declaredFieldName name of the field to read
     * @param sourceObject      object to read from; must not be {@code null}
     * @return the field's current value (may be {@code null})
     * @throws NullPointerException  if {@code sourceObject} is {@code null}
     * @throws IllegalStateException if the field is not found in the class hierarchy or
     *                               cannot be read
     */
    public static Object reflectPrivateFiled(String declaredFieldName, Object sourceObject) {
        java.util.Objects.requireNonNull(sourceObject, "sourceObject");
        try {
            Field declaredField = findField(sourceObject.getClass(), declaredFieldName);
            declaredField.setAccessible(true);
            return declaredField.get(sourceObject);
        } catch (ReflectiveOperationException re) {
            Logger.warn("反射获取私有属性出错", re);
            // Name the field and class so the failure can be diagnosed from the exception alone.
            throw new IllegalStateException("Cannot read field '" + declaredFieldName + "' of "
                    + sourceObject.getClass().getName(), re);
        }
    }

    /** Finds {@code name} declared on {@code type} or on the nearest superclass that declares it. */
    private static Field findField(Class<?> type, String name) throws NoSuchFieldException {
        for (Class<?> current = type; current != null; current = current.getSuperclass()) {
            try {
                return current.getDeclaredField(name);
            } catch (NoSuchFieldException ignored) {
                // Not declared on this class; continue with its superclass.
            }
        }
        throw new NoSuchFieldException(name);
    }
}
/**
 * Result of parsing a dynamic SQL script: the request parameters it references and the
 * columns it returns. Accessors are generated by Lombok's {@code @Data}.
 */
@Data
public class ParseParamsDTO {
/**
 * One request parameter. MysqlDriverAdapter only sets columnName; the remaining fields keep
 * their defaults (required is a primitive boolean, so it defaults to false).
 */
@Data
public static class RequestParams {
private String columnName;
private String columnType;
private boolean required;
private String columnDesc;
// Example value for the parameter.
@ApiModelProperty("参数样例")
private String requestDemo;
}
/**
 * One returned column. MysqlDriverAdapter only sets columnName (the select alias, or the
 * item's SQL text when it has no alias).
 */
@Data
public static class ReturnParams {
private String columnName;
private String columnDesc;
private String columnType;
}
// Request parameters.
@ApiModelProperty("请求参数")
private List<RequestParams> requestParams;
// Return parameters (columns).
@ApiModelProperty("返回参数")
private List<ReturnParams> returnParams;
/** Creates an empty {@link RequestParams}. */
public static ParseParamsDTO.RequestParams buildReqParams() {
return new RequestParams();
}
/** Creates an empty {@link ReturnParams}. */
public static ParseParamsDTO.ReturnParams buildReturnParams() {
return new ReturnParams();
}
}
测试:
原始SQL如下:
SELECT
id as dataId ,
year_month as yearMonth ,
`year` as `year`,
`month` as `month`,
own_code as ownCode,
function as function,
center_code as centerCode,
center_name as centerName ,
related_party as relatedParty,
big_region as bigRegion,
management_org as managementOrg,
hr_org as srOrg
from dwd.`test_demo` where 1=1
<if test= "yearMonth != null and yearMonth !='' " > and year_month='${yearMonth}' </if>
<if test= "year != null and year !='' " > and `year` ='${year}' </if>
<if test= "month != null and month !='' " > and `month` ='${month}' </if>
<if test= "ownCode != null and ownCode !='' " > and own_code='${ownCode}' </if>
<if test= "function != null and function !='' " > and function='${function}' </if>
<if test= "centerCode != null and centerCode !='' " > and center_code='${centerCode}' </if>
<if test= "relatedParty != null and relatedParty !='' " > and related_party='${relatedParty}' </if>
<if test= "bigRegion != null and bigRegion !='' " > and big_region='${bigRegion}' </if>
<if test= "dataIds != null and dataIds.size() > 0 " >
and id in
<foreach item='item' index='index' collection='dataIds' open='(' separator=',' close=')' >
'${item}'
</foreach>
</if>
order by 1,2,3,4,5,6,7,8,9,10,11,12
创建测试类。为了避免在Java字符串中转义引号、换行等特殊字符,这里先将以上SQL做Base64编码,再交给解析器解析:
/**
 * Demo entry point: parses the Base64-encoded sample SQL shown above and prints the extracted
 * request and return parameters as JSON. The SQL is Base64-encoded so that it can be embedded
 * in a Java string literal without escaping.
 */
public static void main(String[] args) throws UnsupportedEncodingException {
    String encodedSql = "U0VMRUNUIAppZCBhcyBkYXRhSWQgLAp5ZWFyX21vbnRoIGFzIHllYXJNb250aCAsCmB5ZWFyYCBhcyBgeWVhcmAsCmBtb250aGAgYXMgYG1vbnRoYCwKb3duX2NvZGUgYXMgb3duQ29kZSwKZnVuY3Rpb24gYXMgZnVuY3Rpb24sCmNlbnRlcl9jb2RlIGFzIGNlbnRlckNvZGUsCmNlbnRlcl9uYW1lIGFzIGNlbnRlck5hbWUgLApyZWxhdGVkX3BhcnR5IGFzIHJlbGF0ZWRQYXJ0eSwKYmlnX3JlZ2lvbiBhcyBiaWdSZWdpb24sCm1hbmFnZW1lbnRfb3JnIGFzIG1hbmFnZW1lbnRPcmcsCmhyX29yZyBhcyBzck9yZwpmcm9tIGR3ZC5gdGVzdF9kZW1vYCB3aGVyZSAgMT0xIAo8aWYgdGVzdD0gInllYXJNb250aCAhPSBudWxsIGFuZCB5ZWFyTW9udGggIT0nJyAiID4gYW5kIHllYXJfbW9udGg9JyR7eWVhck1vbnRofScgPC9pZj4gCjxpZiB0ZXN0PSAieWVhciAhPSBudWxsIGFuZCB5ZWFyICE9JycgIiA+IGFuZCBgeWVhcmAgPScke3llYXJ9JyA8L2lmPiAKPGlmIHRlc3Q9ICJtb250aCAhPSBudWxsIGFuZCBtb250aCAhPScnICIgPiBhbmQgYG1vbnRoYCA9JyR7bW9udGh9JyA8L2lmPiAKPGlmIHRlc3Q9ICJvd25Db2RlICE9IG51bGwgYW5kIG93bkNvZGUgIT0nJyAiID4gYW5kIG93bl9jb2RlPScke293bkNvZGV9JyA8L2lmPiAKPGlmIHRlc3Q9ICJmdW5jdGlvbiAhPSBudWxsIGFuZCBmdW5jdGlvbiAhPScnICIgPiBhbmQgZnVuY3Rpb249JyR7ZnVuY3Rpb259JyA8L2lmPiAKPGlmIHRlc3Q9ICJjZW50ZXJDb2RlICE9IG51bGwgYW5kIGNlbnRlckNvZGUgIT0nJyAiID4gYW5kIGNlbnRlcl9jb2RlPScke2NlbnRlckNvZGV9JyA8L2lmPiAKPGlmIHRlc3Q9ICJyZWxhdGVkUGFydHkgIT0gbnVsbCBhbmQgcmVsYXRlZFBhcnR5ICE9JycgIiA+IGFuZCByZWxhdGVkX3BhcnR5PScke3JlbGF0ZWRQYXJ0eX0nIDwvaWY+IAo8aWYgdGVzdD0gImJpZ1JlZ2lvbiAhPSBudWxsIGFuZCBiaWdSZWdpb24gIT0nJyAiID4gYW5kIGJpZ19yZWdpb249JyR7YmlnUmVnaW9ufScgPC9pZj4gCjxpZiB0ZXN0PSAiZGF0YUlkcyAhPSBudWxsIGFuZCBkYXRhSWRzLnNpemUoKSA+IDAgIiA+IAphbmQgaWQgaW4gCjxmb3JlYWNoIGl0ZW09J2l0ZW0nIGluZGV4PSdpbmRleCcgY29sbGVjdGlvbj0nZGF0YUlkcycgb3Blbj0nKCcgc2VwYXJhdG9yPScsJyBjbG9zZT0nKScgPiAKJyR7aXRlbX0nICAgCjwvZm9yZWFjaD4gCjwvaWY+Cm9yZGVyIGJ5IDEsMiwzLDQsNSw2LDcsOCw5LDEwLDExLDEy";
    ParseParamsDTO result = new MysqlDriverAdapter().parse(encodedSql);
    System.out.println(JSON.toJSON(result));
}
返回结果:
{
"requestParams": [
{
"columnName": "ownCode",
"required": false
},
{
"columnName": "month",
"required": false
},
{
"columnName": "year",
"required": false
},
{
"columnName": "yearMonth",
"required": false
},
{
"columnName": "function",
"required": false
},
{
"columnName": "centerCode",
"required": false
},
{
"columnName": "bigRegion",
"required": false
},
{
"columnName": "dataIds",
"required": false
},
{
"columnName": "relatedParty",
"required": false
}
],
"returnParams": [
{
"columnName": "ownCode"
},
{
"columnName": "dataId"
},
{
"columnName": "month"
},
{
"columnName": "year"
},
{
"columnName": "yearMonth"
},
{
"columnName": "function"
},
{
"columnName": "managementOrg"
},
{
"columnName": "centerCode"
},
{
"columnName": "bigRegion"
},
{
"columnName": "relatedParty"
},
{
"columnName": "srOrg"
},
{
"columnName": "centerName"
}
]
}
最后:
以上就是一个动态MySQL解析工具类,欢迎指正。