Mybatis-Plus(https://github.com/baomidou/mybatis-plus)的乐观锁插件并不能实现更新失败时抛出指定异常,本博文针对此对3.0版本的乐观锁进行了改造,只贴关键代码。
简单介绍一下改造:当一次update发生时,拦截器首先判断是否传入了版本号字段(本代码中是version_val,请按实际命名修改;判断逻辑较复杂,有兴趣的朋友可以看看)。如果没有传入版本号字段,则直接放行、按原逻辑执行update;如果传入了版本号字段且更新成功条数为0,则会将原条件去掉版本号字段后再查询一遍:如果查询结果大于0,说明记录存在,是版本号不匹配导致的更新失败,此时抛出乐观锁异常(可自行定制);如果查询结果为0,说明符合条件的记录本身不存在,直接返回更新结果。
import com.alibaba.druid.util.StringUtils;
import com.baomidou.mybatisplus.annotation.Version;
import com.baomidou.mybatisplus.core.conditions.AbstractWrapper;
import com.baomidou.mybatisplus.core.conditions.ISqlSegment;
import com.baomidou.mybatisplus.core.conditions.Wrapper;
import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
import com.baomidou.mybatisplus.core.conditions.segments.MergeSegments;
import com.baomidou.mybatisplus.core.conditions.segments.NormalSegmentList;
import com.baomidou.mybatisplus.core.conditions.update.UpdateWrapper;
import com.baomidou.mybatisplus.core.metadata.TableFieldInfo;
import com.baomidou.mybatisplus.core.metadata.TableInfo;
import com.baomidou.mybatisplus.core.toolkit.Constants;
import com.baomidou.mybatisplus.core.toolkit.ReflectionKit;
import com.baomidou.mybatisplus.core.toolkit.StringPool;
import com.baomidou.mybatisplus.core.toolkit.TableInfoHelper;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import org.apache.ibatis.binding.MapperMethod;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.executor.SimpleExecutor;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.SqlCommandType;
import org.apache.ibatis.plugin.*;
import java.lang.reflect.Field;
import java.math.BigDecimal;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.Statement;
import java.sql.Timestamp;
import java.time.LocalDate;
import java.time.LocalDateTime;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
/**
 * MyBatis {@code Executor.update} interceptor that applies Mybatis-Plus style optimistic locking
 * ({@link Version}-annotated entity field) and, unlike the stock plugin, throws
 * {@link BizException} when an update affects 0 rows because the version did not match.
 *
 * <p>Flow for an UPDATE whose parameter map carries an entity under {@code Constants.ENTITY}:
 * <ol>
 *   <li>Resolve the expected (original) version: from the entity, from the wrapper's entity, or
 *       from a {@code version_val = #{...}} condition in the wrapper SQL. Without one the update
 *       simply proceeds.</li>
 *   <li>Compute the next version, attach the version condition / new version to the statement
 *       and proceed.</li>
 *   <li>If 0 rows were updated, re-check the same conditions without the version predicate; a hit
 *       means the failure was a version conflict and {@link BizException} is thrown.</li>
 * </ol>
 *
 * <p>Reflection metadata is cached per entity class in {@link ConcurrentHashMap}s.
 */
@Intercepts({@Signature(type = Executor.class, method = "update", args = {MappedStatement.class, Object.class})})
@Slf4j
public class OptimisticLockerExceptionInterceptor implements Interceptor {

    /**
     * Optimistic-lock parameter key (original version value).
     *
     * @deprecated use {@code MybatisConstants.MP_OPTLOCK_VERSION_ORIGINAL} directly.
     */
    @Deprecated
    public static final String MP_OPTLOCK_VERSION_ORIGINAL = MybatisConstants.MP_OPTLOCK_VERSION_ORIGINAL;

    /**
     * Optimistic-lock parameter key (version column name).
     *
     * @deprecated use {@code MybatisConstants.MP_OPTLOCK_VERSION_COLUMN} directly.
     */
    @Deprecated
    public static final String MP_OPTLOCK_VERSION_COLUMN = MybatisConstants.MP_OPTLOCK_VERSION_COLUMN;

    /**
     * Optimistic-lock parameter key (original entity).
     *
     * @deprecated use {@code MybatisConstants.MP_OPTLOCK_ET_ORIGINAL} directly.
     */
    @Deprecated
    public static final String MP_OPTLOCK_ET_ORIGINAL = MybatisConstants.MP_OPTLOCK_ET_ORIGINAL;

    private static final String NAME_ENTITY = Constants.ENTITY;
    private static final String NAME_ENTITY_WRAPPER = Constants.WRAPPER;
    private static final String PARAM_UPDATE_METHOD_NAME = "update";

    /** Prefix of the placeholders Mybatis-Plus wrappers generate for their condition values. */
    private static final String EW_PARAMNAME_VALUE_PAIRS = "#{ew.paramNameValuePairs.";

    /** Version column name (adjust to the actual schema). */
    private static final String VERSION_FIELD_NAME = "version_val";

    /** Matches a single {@code version_val = #{...}} condition segment. */
    private static final Pattern versionPattern = Pattern.compile(VERSION_FIELD_NAME + " = #\\{[^\\}]+\\}");

    /**
     * Matches any wrapper value placeholder anywhere in a segment; group 1 is the key into
     * {@code paramNameValuePairs}. A trailing {@code ,option=...} part inside the braces is skipped.
     */
    private static final Pattern PARAM_PLACEHOLDER_PATTERN =
            Pattern.compile(Pattern.quote(EW_PARAMNAME_VALUE_PAIRS) + "([^,}]+)[^}]*\\}");

    private final Map<Class<?>, EntityField> versionFieldCache = new ConcurrentHashMap<>();
    private final Map<Class<?>, List<EntityField>> entityFieldsCache = new ConcurrentHashMap<>();

    /**
     * Handles {@code Executor.update}. Only UPDATE statements whose parameter is a map holding an
     * entity with a {@link Version} field and a resolvable expected version are treated; everything
     * else proceeds untouched.
     *
     * @param invocation the intercepted {@code update(MappedStatement, Object)} call
     * @return the result of the underlying update (normally the affected row count)
     * @throws BizException when 0 rows were updated although a row matching the conditions without
     *                      the version predicate exists
     * @throws Throwable    anything raised by the underlying executor or by reflection
     */
    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        Object[] args = invocation.getArgs();
        MappedStatement ms = (MappedStatement) args[0];
        if (SqlCommandType.UPDATE != ms.getSqlCommandType() || !(args[1] instanceof Map)) {
            return invocation.proceed();
        }
        @SuppressWarnings("unchecked")
        Map<String, Object> map = (Map<String, Object>) args[1];
        // updateById(et), update(et, wrapper)
        Object et = map.getOrDefault(NAME_ENTITY, null);
        if (et == null) {
            return invocation.proceed();
        }
        Class<?> entityClass = et.getClass();
        TableInfo tableInfo = TableInfoHelper.getTableInfo(entityClass);
        EntityField versionField = getVersionField(entityClass, tableInfo);
        if (versionField == null) {
            return invocation.proceed();
        }
        Object originalVersionVal = resolveOriginalVersion(et, versionField, map);
        if (originalVersionVal == null) {
            // No expected version supplied anywhere: plain update without optimistic locking.
            return invocation.proceed();
        }
        Object updatedVersionVal = getUpdatedVersionVal(originalVersionVal);
        String methodId = ms.getId();
        String methodName = methodId.substring(methodId.lastIndexOf(StringPool.DOT) + 1);
        Object resultObj = PARAM_UPDATE_METHOD_NAME.equals(methodName)
                ? proceedWrapperUpdate(invocation, map, et, versionField, originalVersionVal, updatedVersionVal)
                : proceedEntityUpdate(invocation, map, et, versionField, originalVersionVal, updatedVersionVal);
        if (resultObj instanceof Integer) {
            int effRow = (Integer) resultObj;
            if (updatedVersionVal != null && effRow != 0) {
                // updated version value set to entity
                versionField.getField().set(et, updatedVersionVal);
            } else if (effRow == 0) {
                throwIfVersionConflict(invocation, map, et, tableInfo);
            }
        }
        return resultObj;
    }

    /**
     * Finds the version value the caller expects the row to have. Checks, in order, the entity's
     * own version field, the wrapper entity's {@code versionVal}, and a
     * {@code version_val = #{...}} condition in the wrapper SQL.
     *
     * @return the expected version, or {@code null} when none was supplied
     */
    private Object resolveOriginalVersion(Object et, EntityField versionField, Map<String, Object> map)
            throws IllegalAccessException {
        Object fromEntity = versionField.getField().get(et);
        if (fromEntity != null) {
            return fromEntity;
        }
        Object wrapperParam = map.getOrDefault(NAME_ENTITY_WRAPPER, null);
        if (!(wrapperParam instanceof Wrapper)) {
            return null;
        }
        Wrapper<?> ew = (Wrapper<?>) wrapperParam;
        Object conditionEntity = ew.getEntity();
        if (conditionEntity != null && ((BaseEntity) conditionEntity).getVersionVal() != null) {
            return ((BaseEntity) conditionEntity).getVersionVal();
        }
        String columnName = versionField.getColumnName();
        String sqlSegment = ew.getSqlSegment();
        if (sqlSegment == null || columnName == null || !sqlSegment.contains(columnName)
                || !(ew instanceof AbstractWrapper)) {
            return null;
        }
        Matcher matcher = versionPattern.matcher(sqlSegment);
        if (!matcher.find()) {
            return null;
        }
        String versionPair = matcher.group();
        int prefixAt = versionPair.indexOf(EW_PARAMNAME_VALUE_PAIRS);
        if (prefixAt < 0) {
            // e.g. "version_val = #{et.xxx}": not a wrapper parameter, nothing to look up
            return null;
        }
        String versionKey = versionPair.substring(prefixAt + EW_PARAMNAME_VALUE_PAIRS.length(), versionPair.length() - 1);
        return ((AbstractWrapper<?, ?, ?>) ew).getParamNameValuePairs().get(versionKey);
    }

    /**
     * {@code update(entity, wrapper)} path: forces {@code version = original} into the wrapper
     * (creating an {@link UpdateWrapper} when none was passed), stores the next version on the
     * entity and runs the update.
     */
    private Object proceedWrapperUpdate(Invocation invocation, Map<String, Object> map, Object et,
                                        EntityField versionField, Object originalVersionVal,
                                        Object updatedVersionVal) throws Exception {
        String versionColumn = versionField.getColumnName();
        AbstractWrapper<?, ?, ?> ew = (AbstractWrapper<?, ?, ?>) map.getOrDefault(NAME_ENTITY_WRAPPER, null);
        if (ew == null) {
            UpdateWrapper<?> uw = new UpdateWrapper<>();
            uw.eq(versionColumn, originalVersionVal);
            map.put(NAME_ENTITY_WRAPPER, uw);
        } else {
            for (ISqlSegment segment : getNormalSegments(ew)) {
                if (versionPattern.matcher(segment.getSqlSegment()).find()) {
                    // Segments added via apply() are lambdas; "arg$3" is presumably their captured
                    // value array (compiler-generated name, fragile across versions).
                    Field capturedArgs = getDeclaredField(segment.getClass(), "arg$3");
                    capturedArgs.setAccessible(true);
                    ((Object[]) capturedArgs.get(segment))[0] = originalVersionVal;
                }
            }
            ew.apply(versionColumn + " = {0}", originalVersionVal);
        }
        versionField.getField().set(et, updatedVersionVal);
        return invocation.proceed();
    }

    /**
     * Entity-only path (e.g. {@code updateById}): replaces the entity parameter by a map of its field
     * values with the next version plus the optimistic-lock keys, then runs the update.
     */
    private Object proceedEntityUpdate(Invocation invocation, Map<String, Object> map, Object et,
                                       EntityField versionField, Object originalVersionVal,
                                       Object updatedVersionVal) throws Exception {
        List<EntityField> fields = entityFieldsCache.computeIfAbsent(et.getClass(), this::getFieldsFromClazz);
        Map<String, Object> entityMap = new HashMap<>(fields.size());
        for (EntityField ef : fields) {
            Field fd = ef.getField();
            entityMap.put(fd.getName(), fd.get(et));
        }
        entityMap.put(versionField.getField().getName(), updatedVersionVal);
        entityMap.put(MybatisConstants.MP_OPTLOCK_VERSION_ORIGINAL, originalVersionVal);
        entityMap.put(MybatisConstants.MP_OPTLOCK_VERSION_COLUMN, versionField.getColumnName());
        entityMap.put(MybatisConstants.MP_OPTLOCK_ET_ORIGINAL, et);
        map.put(NAME_ENTITY, entityMap);
        return invocation.proceed();
    }

    /**
     * Called after an update hit 0 rows: re-checks the conditions without the version predicate and
     * throws {@link BizException} when a row still matches (i.e. the version was stale).
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    private void throwIfVersionConflict(Invocation invocation, Map<String, Object> map, Object et,
                                        TableInfo tableInfo) throws Throwable {
        log.debug("有乐观锁");
        Wrapper<?> ew = (Wrapper<?>) map.getOrDefault(NAME_ENTITY_WRAPPER, null);
        if (ew == null) {
            // wrapper == null (updateById): et is the caller's entity, i.e. the same object as
            // "param1"; it is used directly because the ParamMap may not contain that key.
            BaseEntity baseEntity = (BaseEntity) et;
            if (baseEntity.getVersionVal() != null && baseEntity.selectById(baseEntity) != null) {
                throw new BizException(ErrorCodeEnum.OPTIMISTICLOCKER_EXCEPTION_UPDATE_FAIL);
            }
        } else if (ew.getEntity() != null) {
            // wrapper with a condition entity: drop its version and look the row up again.
            // NOTE(review): this clears versionVal on the caller's condition entity.
            BaseEntity conditionEntity = (BaseEntity) ew.getEntity();
            conditionEntity.setVersionVal(null);
            // et is used rather than map.get("et"): that key may already hold the entity map.
            Object found = ((BaseEntity) et).selectOne(new QueryWrapper().setEntity(conditionEntity));
            if (found != null) {
                throw new BizException(ErrorCodeEnum.OPTIMISTICLOCKER_EXCEPTION_UPDATE_FAIL);
            }
        } else if (ew instanceof AbstractWrapper
                && existsIgnoringVersion(invocation, tableInfo, (AbstractWrapper<?, ?, ?>) ew)) {
            throw new BizException(ErrorCodeEnum.OPTIMISTICLOCKER_EXCEPTION_UPDATE_FAIL);
        }
    }

    /**
     * Rebuilds the wrapper's WHERE conditions as {@code SELECT COUNT(*)} with every version
     * predicate replaced by {@code 1 = 1}, binds the condition values as JDBC parameters and runs
     * it on the update's own connection.
     *
     * @return true when at least one row matches the conditions without the version check
     */
    private boolean existsIgnoringVersion(Invocation invocation, TableInfo tableInfo,
                                          AbstractWrapper<?, ?, ?> wrapper) throws Exception {
        NormalSegmentList segments = getNormalSegments(wrapper);
        if (segments.isEmpty()) {
            // No WHERE clause to re-check.
            return false;
        }
        Map<String, Object> paramPairs = wrapper.getParamNameValuePairs();
        StringBuilder sql = new StringBuilder("SELECT COUNT(*) FROM ")
                .append(tableInfo.getTableName())
                .append(" WHERE ");
        List<Object> bindValues = new ArrayList<>();
        int versionIndex = -10;
        for (int i = 0; i < segments.size(); i++) {
            String segment = segments.get(i).getSqlSegment();
            if (versionPattern.matcher(segment).find()) {
                // condition added via apply("version_val = {0}")
                sql.append(" 1 = 1 ");
                continue;
            }
            if (segment.equals(VERSION_FIELD_NAME)) {
                versionIndex = i;
            }
            if (i == versionIndex || i == versionIndex + 2) {
                // eq("version_val", v) yields [column, "=", value]; column and value become 1
                sql.append(1);
            } else {
                appendWithBindings(sql, segment, paramPairs, bindValues);
            }
            sql.append(' ');
        }
        // The connection belongs to the MyBatis transaction running this update, so it is
        // deliberately NOT closed here; only the statement and result set are.
        Connection connection = ((Executor) invocation.getTarget()).getTransaction().getConnection();
        try (PreparedStatement statement = connection.prepareStatement(sql.toString())) {
            for (int i = 0; i < bindValues.size(); i++) {
                statement.setObject(i + 1, toJdbcValue(bindValues.get(i)));
            }
            try (ResultSet resultSet = statement.executeQuery()) {
                return resultSet.next() && resultSet.getLong(1) > 0;
            }
        }
    }

    /**
     * Appends {@code segment} to {@code sql}, turning each wrapper placeholder into {@code ?} and
     * collecting its value; placeholders whose value is null become the literal {@code NULL}.
     */
    private static void appendWithBindings(StringBuilder sql, String segment, Map<String, Object> paramPairs,
                                           List<Object> bindValues) {
        Matcher matcher = PARAM_PLACEHOLDER_PATTERN.matcher(segment);
        int copiedUpTo = 0;
        while (matcher.find()) {
            sql.append(segment, copiedUpTo, matcher.start());
            Object value = paramPairs.get(matcher.group(1));
            if (value == null) {
                sql.append("NULL");
            } else {
                sql.append('?');
                bindValues.add(value);
            }
            copiedUpTo = matcher.end();
        }
        sql.append(segment, copiedUpTo, segment.length());
    }

    /** Converts java.time values to their JDBC counterparts; other values are passed through. */
    private static Object toJdbcValue(Object value) {
        if (value instanceof LocalDateTime) {
            return Timestamp.valueOf((LocalDateTime) value);
        }
        if (value instanceof LocalDate) {
            return java.sql.Date.valueOf((LocalDate) value);
        }
        return value;
    }

    /** Reads the private {@code expression.normal} condition list of a wrapper via reflection. */
    private NormalSegmentList getNormalSegments(AbstractWrapper<?, ?, ?> wrapper)
            throws NoSuchFieldException, IllegalAccessException {
        Field expressionField = getDeclaredField(wrapper.getClass(), "expression");
        expressionField.setAccessible(true);
        MergeSegments expression = (MergeSegments) expressionField.get(wrapper);
        Field normalField = expression.getClass().getDeclaredField("normal");
        normalField.setAccessible(true);
        return (NormalSegmentList) normalField.get(expression);
    }

    /**
     * This method provides the control for version value.<BR>
     * Returned value type must be the same as original one.
     *
     * @param originalVersionVal current version value (non-null)
     * @return updated version val: +1 for Long/Integer, "now" for Date/Timestamp/LocalDateTime,
     *         otherwise the original value unchanged
     */
    protected Object getUpdatedVersionVal(Object originalVersionVal) {
        Class<?> versionValClass = originalVersionVal.getClass();
        if (long.class.equals(versionValClass) || Long.class.equals(versionValClass)) {
            return ((long) originalVersionVal) + 1;
        } else if (int.class.equals(versionValClass) || Integer.class.equals(versionValClass)) {
            return ((int) originalVersionVal) + 1;
        } else if (Date.class.equals(versionValClass)) {
            return new Date();
        } else if (Timestamp.class.equals(versionValClass)) {
            return new Timestamp(System.currentTimeMillis());
        } else if (LocalDateTime.class.equals(versionValClass)) {
            return LocalDateTime.now();
        }
        // not supported type, return original val.
        return originalVersionVal;
    }

    @Override
    public Object plugin(Object target) {
        if (target instanceof Executor) {
            return Plugin.wrap(target, this);
        }
        return target;
    }

    @Override
    public void setProperties(Properties properties) {
        // to do nothing
    }

    /** Cached lookup of the {@link Version} field of {@code parameterClass} (null when absent). */
    private EntityField getVersionField(Class<?> parameterClass, TableInfo tableInfo) {
        return versionFieldCache.computeIfAbsent(parameterClass, clazz -> getVersionFieldRegular(clazz, tableInfo));
    }

    /**
     * Reflectively finds the {@link Version}-annotated field of the entity class, walking up the
     * class hierarchy.
     *
     * @param parameterClass entity class
     * @param tableInfo      Mybatis-Plus table metadata of the entity, used to resolve the column
     * @return the version field with its column name, or {@code null} when none exists
     */
    private EntityField getVersionFieldRegular(Class<?> parameterClass, TableInfo tableInfo) {
        if (parameterClass == null || Object.class.equals(parameterClass)) {
            return null;
        }
        return ReflectionKit.getFieldList(parameterClass).stream()
                .filter(e -> e.isAnnotationPresent(Version.class))
                .findFirst()
                .map(field -> {
                    field.setAccessible(true);
                    String column = tableInfo.getFieldList().stream()
                            .filter(e -> field.getName().equals(e.getProperty()))
                            .map(TableFieldInfo::getColumn)
                            .findFirst()
                            .orElse(null);
                    return new EntityField(field, true, column);
                })
                .orElseGet(() -> getVersionFieldRegular(parameterClass.getSuperclass(), tableInfo));
    }

    /** All (accessible) fields of the entity class, flagged with whether they carry {@link Version}. */
    private List<EntityField> getFieldsFromClazz(Class<?> parameterClass) {
        return ReflectionKit.getFieldList(parameterClass).stream().map(field -> {
            field.setAccessible(true);
            return new EntityField(field, field.isAnnotationPresent(Version.class));
        }).collect(Collectors.toList());
    }

    /** Reflective handle on an entity field plus its optimistic-lock metadata. */
    @Data
    private static class EntityField {
        private Field field;
        /** Whether the field carries {@link Version}. */
        private boolean version;
        /** Mapped column name; only set for the version field and null if no mapping was found. */
        private String columnName;

        EntityField(Field field, boolean version) {
            this.field = field;
            this.version = version;
        }

        public EntityField(Field field, boolean version, String columnName) {
            this.field = field;
            this.version = version;
            this.columnName = columnName;
        }
    }

    /**
     * Looks up a declared field on {@code clazz} or the nearest superclass declaring it.
     *
     * @throws NoSuchFieldException when no class in the hierarchy declares {@code fieldName}
     */
    private Field getDeclaredField(Class<?> clazz, String fieldName) throws NoSuchFieldException {
        for (Class<?> current = clazz; current != null; current = current.getSuperclass()) {
            try {
                return current.getDeclaredField(fieldName);
            } catch (NoSuchFieldException ignored) {
                // not declared here; keep walking up the hierarchy
            }
        }
        throw new NoSuchFieldException(fieldName + " not found in the hierarchy of " + clazz);
    }
}