MybatisInterceptor 自定义拦截器
* 拦截所有的 update/insert/insertbatch 操作
* 拦截方法执行前:为新增数据生成自定义 id,写回实体并返回该 id(重点关注 insertbatch 类型的解析:批量插入时如何为每条数据赋值并返回自定义的 id)
* 拦截方法执行后:获取执行 sql 的表名和参数,处理业务逻辑
/**
* mybatis拦截器
* 拦截所有的update/insert/insertbatch操作
* 拦截方法执行前:自定义id插入并返回id(重点关注 insertbach类型时的解析)
* 拦截方法执行后:获取执行sql的表名和参数,处理业务逻辑
*/
@Slf4j
@Component
@Intercepts({ @Signature(type = Executor.class, method = "update", args = { MappedStatement.class, Object.class }) })
public class MybatisInterceptor implements Interceptor {
@Value("${buzhongyao.enabled}")
private boolean language;
@Override
public Object intercept(Invocation invocation) throws Throwable {
log.info("----进入mybatis拦截器----");
MappedStatement mappedStatement = (MappedStatement) invocation.getArgs()[0];
//获取执行方法的位置
String sqlId = mappedStatement.getId();
log.info("----------mapper名称+方法名" + sqlId);
SqlCommandType sqlCommandType = mappedStatement.getSqlCommandType();
Object parameter = invocation.getArgs()[1];
//log.debug("------sqlCommandType------" + sqlCommandType);
if (parameter == null) {
return invocation.proceed();
}
List<String> bizIds = new ArrayList<String>();
if (SqlCommandType.INSERT == sqlCommandType) {
log.debug("处理INSERT类型");
// 获取新增数据的业务id:拦截并手动插入id(insert单条测试通过,insertbatch 测试通过)
bizIds = executeInsert(parameter);
}
if (SqlCommandType.UPDATE == sqlCommandType) {
log.debug("处理UPDATE类型");
// 获取更新数据的业务id
String id = executeUpdate(parameter);
bizIds.add(id);
}
// 跳出拦截器,继续处理原业务
Object result = invocation.proceed();
// 拦截后执行代码块:保存多语言设置
log.info("mybatis语言配置开关:"+this.language);
if(this.language){
try {
if(null != bizIds && bizIds.size() >0){
for(String bizId : bizIds){
executeAfterProceed(parameter,bizId,mappedStatement);
}
}
} catch (Exception e) {
log.error("Mybatis拦截器后置处理异常 原因:" + e);
}
}
return result;
}
/**
* 拦截器后置处理
* @param parameter
* @param bizId
* @param mappedStatement
* @throws Exception
*/
private void executeAfterProceed(Object parameter,String bizId,MappedStatement mappedStatement) throws Exception {
log.info("拦截器后置处理(语言)--------业务id:"+bizId);
// 多语言
SqlCommandType sqlCommandType = mappedStatement.getSqlCommandType();
BoundSql boundSql = mappedStatement.getBoundSql(parameter);
Configuration configuration= mappedStatement.getConfiguration();
String sql = showSql(configuration,boundSql);
//log.info("mybatis拦截器sql语句:"+sql);
// 获取sql中的表名
List<String> tablenames = getTableNames(sql);
log.info("获取sql中的表名:"+tablenames);
if(null != tablenames && tablenames.size()>0){
String tableName = tablenames.get(0);
// sys_language ,sys_log 开头的表不需要拦截配置语言
if(tableName.indexOf("sys_language") >= 0 || tableName.indexOf("sys_log") >= 0){
return;
}else{
List<SysLanguage> reqSysLanguageList = new ArrayList<SysLanguage>();
if (parameter instanceof ParamMap) {
// update时这样
ParamMap<?> p = (ParamMap<?>) parameter;
if (p.containsKey("et")) {
parameter = p.get("et");
}else if(p.containsKey("param1")) {
parameter = p.get("param1");
}else{
// insertbatch时
List params = (List)p.get("list");
for(Object obj : params){
Field fieldId = obj.getClass().getDeclaredField("id");
fieldId.setAccessible(true);
String dtl_id = (String)fieldId.get(obj);
fieldId.setAccessible(false);
if(bizId.equals(dtl_id)){
parameter = obj;
break;
}
}
}
}
String reqJson = JSONObject.toJSONString(parameter);
if(reqJson.contains("sysLanguageList")){
Field field = parameter.getClass().getSuperclass().getDeclaredField("sysLanguageList");
if(null != field){
field.setAccessible(true);
reqSysLanguageList = (List<SysLanguage>) field.get(parameter);
field.setAccessible(false);
}
}else{
log.info("没有语言请求配置信息,返回");
return;
}
// 创建bean对象
//SysLanguageMapper sysLanguageMapper = (SysLanguageMapper) SpringContextHolder.getBean("sysLanguageMapper");
}
}
}
/**
* 处理insert类型
* @param parameter
*/
private List<String> executeInsert(Object parameter) {
Field[] fields = null;// insert单个表
List<Field[]> allFieldList = null;// insertbatch多个表
List<String> ids = new ArrayList<String>();
if (parameter instanceof ParamMap) {
ParamMap<?> p = (ParamMap<?>) parameter;
//update-begin-for:批量更新报错issues/IZA3Q--
if (p.containsKey("et")) {
parameter = p.get("et");
fields = oConvertUtils.getAllFields(parameter);
}else if(p.containsKey("param1")) {
parameter = p.get("param1");
fields = oConvertUtils.getAllFields(parameter);
}else{
// insertbatch时
parameter = p.get("list");
allFieldList = oConvertUtils.getListAllFields(parameter);
}
//update-end- for:批量更新报错issues/IZA3Q-
} else {
fields = oConvertUtils.getAllFields(parameter);
}
if(null != fields){
String id = null;
for (Field field : fields) {
id = excuteInsertField(parameter,field,id);
if(null != id){
ids.add(id);
id = null;// id 重新置空(避免重复)
}
}
return ids;
}
if(null != allFieldList){
for(int i=0;i<allFieldList.size();i++){
Field[] lfields = allFieldList.get(i);
String id = null;
for (Field field : lfields) {
id = excuteInsertField(((List)parameter).get(i),field,id);
if(null != id){
ids.add(id);
id = null;// id 重新置空(避免重复)
}
}
}
return ids;
}
return ids;
}
private String excuteInsertField(Object parameter,Field field,String id){
log.debug("------field.name------" + field.getName());
try {
if ("id".equals(field.getName())) {
id = UUIDGenerator.getId();
field.setAccessible(true);
field.set(parameter, id);
//log.info("拦截器前置处理(多语言)Insert--Id ="+id);
field.setAccessible(false);
}
} catch (Exception e) {
log.debug(e.getMessage());
id = null;
}
return id;
}
/**
* 处理Update类型
* @param parameter
* @return
*/
private String executeUpdate(Object parameter){
String id = null;
Field[] fields = null;
if (parameter instanceof ParamMap) {
ParamMap<?> p = (ParamMap<?>) parameter;
//update-begin-for:批量更新报错issues/IZA3Q--
if (p.containsKey("et")) {
parameter = p.get("et");
} else {
parameter = p.get("param1");
}
//update-end- for:批量更新报错issues/IZA3Q-
//update-begin- for:更新指定字段时报错 issues/#516-
/*if (parameter == null) {
invocation.proceed()
}*/
//update-end- for:更新指定字段时报错 issues/#516-
fields = oConvertUtils.getAllFields(parameter);
} else {
fields = oConvertUtils.getAllFields(parameter);
}
for (Field field : fields) {
//log.debug("------field.name------" + field.getName());
try {
if ("id".equals(field.getName())) {
field.setAccessible(true);
Object local_id = field.get(parameter);
id = (String)local_id;
//log.info("拦截器前置处理(多语言)Update Id ="+local_id);
field.setAccessible(false);
}
} catch (Exception e) {
e.printStackTrace();
}
}
return id;
}
//update-begin--for:关于使用Quzrtz 开启线程任务, #465
private LoginUser getLoginUser() {
LoginUser sysUser = null;
try {
sysUser = SecurityUtils.getSubject().getPrincipal() != null ? (LoginUser) SecurityUtils.getSubject().getPrincipal() : null;
} catch (Exception e) {
//e.printStackTrace();
sysUser = null;
}
return sysUser;
}
//update-end--for:关于使用Quzrtz 开启线程任务, #465
private static String getParameterValue(Object obj) {
String value = null;
if (obj instanceof String) {
value = "'" + obj.toString() + "'";
value = value.replaceAll("\\\\", "\\\\\\\\");
value = value.replaceAll("\\$", "\\\\\\$");
} else if (obj instanceof Date) {
DateFormat formatter = DateFormat.getDateTimeInstance(DateFormat.DEFAULT, DateFormat.DEFAULT, Locale.CHINA);
value = "'" + formatter.format(obj) + "'";
} else {
if (obj != null) {
value = obj.toString();
} else {
value = "''";
}
}
return value;
}
public static String showSql(Configuration configuration, BoundSql boundSql) {
Object parameterObject = boundSql.getParameterObject();
List<ParameterMapping> parameterMappings = boundSql.getParameterMappings();
String sql = boundSql.getSql().replaceAll("[\\s]+", " ");
if (parameterMappings.size() > 0 && parameterObject != null) {
TypeHandlerRegistry typeHandlerRegistry = configuration.getTypeHandlerRegistry();
if (typeHandlerRegistry.hasTypeHandler(parameterObject.getClass())) {
sql = sql.replaceFirst("\\?", getParameterValue(parameterObject));
} else {
MetaObject metaObject = configuration.newMetaObject(parameterObject);
for (ParameterMapping parameterMapping : parameterMappings) {
String propertyName = parameterMapping.getProperty();
if (metaObject.hasGetter(propertyName)) {
Object obj = metaObject.getValue(propertyName);
sql = sql.replaceFirst("\\?", getParameterValue(obj));
} else if (boundSql.hasAdditionalParameter(propertyName)) {
Object obj = boundSql.getAdditionalParameter(propertyName);
sql = sql.replaceFirst("\\?", getParameterValue(obj));
}
}
}
}
return sql;
}
/**
* detect table names from given table
* ATTENTION : WE WILL SKIP SCALAR SUBQUERY IN PROJECTION CLAUSE
* */
private static List<String> getTableNames(String sql) throws Exception {
CCJSqlParserManager pm = new CCJSqlParserManager();
List<String> tablenames = new ArrayList<String>();
TablesNamesFinder tablesNamesFinder = new TablesNamesFinder();
Statement statement = pm.parse(new StringReader(sql));
if (statement instanceof Select) {
//tablenames = tablesNamesFinder.getTableList((Select) statement);
return null;
} else if (statement instanceof Update) {
Update updateStatement = (Update) statement;
String tablename = updateStatement.getTable().getName();
tablenames.add(tablename);
} else if (statement instanceof Delete) {
return null;
} else if (statement instanceof Replace) {
return null;
} else if (statement instanceof Insert) {
Insert insertStatement = (Insert) statement;
String tablename = insertStatement.getTable().getName();
tablenames.add(tablename);
}
return tablenames;
}
/**
* 作用:让mybatis判断,是否要进行拦截,然后做出决定是否生成一个代理
* @param target
* @return
*/
@Override
public Object plugin(Object target) {
return Plugin.wrap(target, this);
}
/**
* 配置自定义属性变量
* @param properties
*/
@Override
public void setProperties(Properties properties) {
// TODO Auto-generated method stub
}
private static String getTenantCode(){
return "1";
}
public static void main(String[] args) throws Exception {
String sql ="随便写个sql例子吧";
List<String> tablenames = getTableNames(sql);
System.out.println(tablenames);
}
}
oConvertUtils.java
/**
 * Returns every field declared on the object's class and on all of its superclasses.
 * Fields of the most-derived class come first, then those of each superclass in turn,
 * up to {@code Object}.
 *
 * @param object the instance to inspect; {@code null} yields an empty array instead of an NPE
 * @return all declared fields (any visibility, including static ones) along the class hierarchy
 */
public static Field[] getAllFields(Object object) {
    if (object == null) {
        return new Field[0];
    }
    List<Field> fieldList = new ArrayList<>();
    for (Class<?> clazz = object.getClass(); clazz != null; clazz = clazz.getSuperclass()) {
        fieldList.addAll(Arrays.asList(clazz.getDeclaredFields()));
    }
    return fieldList.toArray(new Field[0]);
}
/**
 * For a batch-insert parameter, returns the fields of each element (as {@code getAllFields}
 * does), one array per element, in list order. A {@code null} element contributes an empty
 * array, so index {@code i} of the result always corresponds to element {@code i} of the list.
 *
 * @param object expected to be a {@link List}; any other value (including {@code null})
 *               yields an empty result
 * @return one field array per list element
 */
public static List<Field[]> getListAllFields(Object object) {
    List<Field[]> fieldList = new ArrayList<>();
    if (object instanceof List) {
        for (Object element : (List<?>) object) {
            fieldList.add(element == null ? new Field[0] : getAllFields(element));
        }
    }
    return fieldList;
}
在 pom.xml 中引入依赖(JSqlParser,用于从 sql 中解析表名):
<!--sql解析 多语言 -->
<dependency>
<groupId>com.github.jsqlparser</groupId>
<artifactId>jsqlparser</artifactId>
<version>0.9</version>
</dependency>
参考链接: