简介
由于mybatis自带的二级缓存在关联查询场景下无法及时更新,不能直接使用,因此这里通过mybatis拦截器自行实现二级缓存。缓存容器采用hutool提供的缓存实现,也可以替换为自定义的缓存实现。本文仅提供思路,如有问题,欢迎提出。
实现原理
- 通过拦截StatementHandler接口下的query以及update方法来实现缓存的读写及更新操作;
- 通过拦截Executor接口下的commit、rollback、close方法来实现缓存的提交及回滚操作;
具体实现
- 定义一个mybatis插件抽象类AbstractPlugin。
import org.apache.ibatis.plugin.Interceptor;
import org.apache.ibatis.plugin.Invocation;
import org.apache.ibatis.plugin.Plugin;
import org.apache.ibatis.reflection.MetaObject;
import org.apache.ibatis.reflection.SystemMetaObject;
import java.beans.PropertyDescriptor;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.util.Map;
import java.util.Properties;
/**
 * Base class for MyBatis plugins.
 *
 * <p>Provides pass-through implementations of the {@link Interceptor} methods and a helper
 * that strips the JDK dynamic proxies created by {@link Plugin#wrap(Object, Interceptor)}
 * to reach the real intercepted object.
 *
 * @author gaolj
 */
public abstract class AbstractPlugin implements Interceptor {
    /**
     * Name of the field inside {@link Plugin} that holds the wrapped object.
     */
    private static final String PLUGIN_TARGET_FIELD = "target";

    /**
     * Default interception: simply continues with the intercepted call.
     *
     * @param invocation the intercepted invocation
     * @return whatever the intercepted method returns
     * @throws Throwable anything thrown by the intercepted method
     */
    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        return invocation.proceed();
    }

    /**
     * Wraps the target in a proxy that routes matching calls to this interceptor.
     *
     * @param target object to wrap
     * @return the proxy produced by {@link Plugin#wrap(Object, Interceptor)}
     */
    @Override
    public Object plugin(Object target) {
        return Plugin.wrap(target, this);
    }

    /**
     * No properties are consumed by default; subclasses override when they need configuration.
     *
     * @param properties plugin properties from the MyBatis configuration
     */
    @Override
    public void setProperties(Properties properties) {
    }

    /**
     * Unwraps a (possibly multiply) plugin-proxied object down to the real target.
     *
     * <p>A target intercepted by several plugins is wrapped once per plugin, so the layers
     * alternate proxy -> {@link Plugin} handler -> wrapped target. Each iteration peels one
     * such layer, which means any nesting depth is handled. Unwrapping stops at the first
     * object that is not a JDK proxy, or at a proxy whose handler is not a {@link Plugin}
     * (that proxy is returned as-is because its wrapped object is unknown here).
     *
     * @param target the object to unwrap, typically {@code invocation.getTarget()}; may be null
     * @return the innermost non-plugin-proxied object, or {@code target} itself when it is not
     *         a plugin proxy (null in, null out)
     */
    protected Object getUnProxyObject(Object target) {
        Object current = target;
        // Fully qualified names are used so that this block needs no extra imports.
        while (current != null && java.lang.reflect.Proxy.isProxyClass(current.getClass())) {
            java.lang.reflect.InvocationHandler handler =
                    java.lang.reflect.Proxy.getInvocationHandler(current);
            if (!(handler instanceof Plugin)) {
                break;
            }
            // Plugin keeps the wrapped object in a private field; MetaObject reads it reflectively.
            current = SystemMetaObject.forObject(handler).getValue(PLUGIN_TARGET_FIELD);
        }
        return current;
    }
}
- 定义一个缓存的抽象类AbstractCachePlugin,在这个类中实现缓存的存储等方法。缓存分为全局缓存以及connection缓存。当connection提交时候,将connection缓存提交到全局缓存。如果被回滚,则清空connection缓存。通过net.sf.jsqlparser解析sql语句,来获取缓存更新依据(表)。
import cn.hutool.cache.Cache;
import cn.hutool.cache.CacheUtil;
import cn.hutool.cache.impl.CacheObj;
import cn.hutool.core.convert.Convert;
import cn.hutool.crypto.SecureUtil;
import cn.hutool.log.LogFactory;
import com.ctm.repair.persistent.plugin.AbstractPlugin;
import net.sf.jsqlparser.JSQLParserException;
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
import net.sf.jsqlparser.util.TablesNamesFinder;
import java.sql.Connection;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.locks.ReentrantReadWriteLock;
/**
 * Base class for the cache plugins; owns the cache storage and its maintenance operations.
 *
 * <p>Two cache levels are kept in static state shared by every subclass instance:
 * <ul>
 *   <li>a global LRU cache holding results that have been published on commit;</li>
 *   <li>one LRU cache per {@link Connection}, holding results read through that connection
 *       until it is committed (published to the global cache) or rolled back (discarded).</li>
 * </ul>
 * Cache keys have the form {@code md5;table1;table2;} so entries can be invalidated by the
 * tables a statement touches; table names are obtained by parsing the SQL with JSqlParser.
 *
 * @author gaolj
 */
public abstract class AbstractCachePlugin extends AbstractPlugin {
    /**
     * Semicolon; delimits the parts of a cache key.
     */
    public final static String SEMICOLON = ";";
    /**
     * Global cache (LRU, capacity 4096).
     */
    private final static Cache<String, Object> GLOBAL_CACHE = CacheUtil.newLRUCache(4096);
    /**
     * Per-connection caches, keyed by connection.
     */
    private final static Map<Connection, Cache<String, Object>> CONNECTION_CACHE_MAP = new ConcurrentHashMap<>();
    /**
     * Lock guarding the static caches above. It is static on purpose: the caches are shared
     * by all plugin instances, so a per-instance lock would not make them exclude each other.
     */
    private final static ReentrantReadWriteLock CACHE_LOCK = new ReentrantReadWriteLock();
    private final static ReentrantReadWriteLock.ReadLock READ_LOCK = CACHE_LOCK.readLock();
    private final static ReentrantReadWriteLock.WriteLock WRITE_LOCK = CACHE_LOCK.writeLock();

    /**
     * Looks up the cache owned by a connection.
     *
     * @param connection connection; may be null
     * @return the connection's cache, or null when the connection is null or has no cache yet
     */
    private Cache<String, Object> findConnectionCache(Connection connection) {
        // ConcurrentHashMap rejects null keys, so guard explicitly.
        return connection == null ? null : CONNECTION_CACHE_MAP.get(connection);
    }

    /**
     * Selects the cache to read from: the connection's own cache when it has one,
     * otherwise the global cache.
     *
     * @param connection connection
     * @return cache to read from, never null
     */
    private Cache<String, Object> getCache(Connection connection) {
        // Single lookup instead of containsKey + get, so a concurrent removal cannot yield null.
        Cache<String, Object> connectionCache = this.findConnectionCache(connection);
        return connectionCache != null ? connectionCache : GLOBAL_CACHE;
    }

    /**
     * Collects the keys of a cache that contain the given fragment.
     *
     * @param cache    cache to scan
     * @param fragment text to look for inside each key
     * @return matching keys, possibly empty
     */
    private List<String> findKeysContaining(Cache<String, Object> cache, String fragment) {
        List<String> keys = new ArrayList<>();
        Iterator<CacheObj<String, Object>> iterator = cache.cacheObjIterator();
        while (iterator.hasNext()) {
            String key = iterator.next().getKey();
            if (key.contains(fragment)) {
                keys.add(key);
            }
        }
        return keys;
    }

    /**
     * Tells whether a value is cached under the key.
     *
     * @param connection connection
     * @param key        cache key
     * @return true when the selected cache contains the key, false otherwise
     */
    public boolean cacheContainsKey(Connection connection, String key) {
        READ_LOCK.lock();
        try {
            return this.getCache(connection).containsKey(key);
        } finally {
            READ_LOCK.unlock();
        }
    }

    /**
     * Reads a cached value.
     *
     * @param connection connection
     * @param key        cache key
     * @return the cached value, or null when nothing is cached under the key
     */
    public Object getCacheValue(Connection connection, String key) {
        READ_LOCK.lock();
        try {
            return this.getCache(connection).get(key);
        } finally {
            LogFactory.get().debug("get results from cache,cache key: " + key);
            READ_LOCK.unlock();
        }
    }

    /**
     * Stores a value in the connection's cache, creating that cache (LRU, capacity 1024)
     * on first use.
     *
     * @param connection connection
     * @param key        cache key
     * @param value      value to cache
     */
    public void setCacheValue(Connection connection, String key, Object value) {
        WRITE_LOCK.lock();
        try {
            CONNECTION_CACHE_MAP
                    .computeIfAbsent(connection, ignored -> CacheUtil.newLRUCache(1024))
                    .put(key, value);
        } finally {
            LogFactory.get().debug("write connection cache,cache key: " + key);
            WRITE_LOCK.unlock();
        }
    }

    /**
     * Invalidates every cache entry that depends on a table touched by the given statement,
     * in both the connection's cache and the global cache.
     *
     * <p>If anything goes wrong (for example the SQL cannot be parsed), both caches are
     * cleared completely, since it is then unknown which entries are stale.
     *
     * @param connection connection executing the statement
     * @param sql        the statement being executed
     */
    public void flushCache(Connection connection, String sql) {
        WRITE_LOCK.lock();
        try {
            List<String> tableNameList = this.getTableNameList(sql);
            if (tableNameList == null || tableNameList.isEmpty()) {
                return;
            }
            Cache<String, Object> connectionCache = this.findConnectionCache(connection);
            for (String tableName : tableNameList) {
                String tableNameKey = SEMICOLON + tableName.toLowerCase() + SEMICOLON;
                // Invalidate the connection cache.
                if (connectionCache != null) {
                    for (String key : this.findKeysContaining(connectionCache, tableNameKey)) {
                        connectionCache.remove(key);
                        LogFactory.get().debug("remove connection cache when update table, table name: " + tableName + ", key: " + key);
                    }
                }
                // Invalidate the global cache, restricted to entries that depend on this table.
                for (String key : this.findKeysContaining(GLOBAL_CACHE, tableNameKey)) {
                    GLOBAL_CACHE.remove(key);
                    LogFactory.get().debug("remove global cache when update table, table name: " + tableName + ", key: " + key);
                    if (GLOBAL_CACHE.containsKey(key)) {
                        // Removal did not take effect: fall back to dropping everything.
                        GLOBAL_CACHE.clear();
                        LogFactory.get().debug("remove global all cache because of fail delete when update table, table name: " + tableName + ", key: " + key);
                        break;
                    }
                }
            }
        } catch (Exception e) {
            LogFactory.get().warn("flush cache fail:\r\n" + e.getMessage());
            GLOBAL_CACHE.clear();
            Cache<String, Object> connectionCache = this.findConnectionCache(connection);
            if (connectionCache != null) {
                connectionCache.clear();
            }
            LogFactory.get().debug("flush all cache!");
        } finally {
            WRITE_LOCK.unlock();
        }
    }

    /**
     * Extracts the names of the tables referenced by a statement.
     *
     * @param sql the statement
     * @return table names as reported by the parser; an immutable empty list when there are none
     * @throws JSQLParserException when the statement cannot be parsed
     */
    public List<String> getTableNameList(String sql) throws JSQLParserException {
        net.sf.jsqlparser.statement.Statement statement = CCJSqlParserUtil.parse(sql);
        List<String> tableNameList = new TablesNamesFinder().getTableList(statement);
        // An immutable empty list avoids handing one shared mutable instance to every caller.
        return tableNameList != null ? tableNameList : Collections.<String>emptyList();
    }

    /**
     * Rolls the cache back: discards the connection's cache.
     *
     * @param connection connection
     */
    public void rollbackCache(Connection connection) {
        WRITE_LOCK.lock();
        try {
            if (connection != null && CONNECTION_CACHE_MAP.remove(connection) != null) {
                LogFactory.get().debug("connection cache remove!");
            }
        } finally {
            WRITE_LOCK.unlock();
        }
    }

    /**
     * Commits the cache: copies every entry of the connection's cache into the global cache
     * and then drops the connection's cache.
     *
     * @param connection connection
     */
    public void commitCache(Connection connection) {
        WRITE_LOCK.lock();
        try {
            Cache<String, Object> connectionCache = this.findConnectionCache(connection);
            if (connectionCache == null) {
                return;
            }
            Iterator<CacheObj<String, Object>> iterator = connectionCache.cacheObjIterator();
            while (iterator.hasNext()) {
                CacheObj<String, Object> cacheObj = iterator.next();
                String key = cacheObj.getKey();
                // Drop any previous entry first, then store the fresh value.
                GLOBAL_CACHE.remove(key);
                GLOBAL_CACHE.put(key, cacheObj.getValue());
                LogFactory.get().debug("global cache write,write key: " + key);
            }
            CONNECTION_CACHE_MAP.remove(connection);
            LogFactory.get().debug("connection cache remove!");
        } finally {
            WRITE_LOCK.unlock();
        }
    }

    /**
     * Builds the cache key for a query: the MD5 of the statement id and serialized parameters,
     * followed by the lower-cased names of all tables the query reads, each terminated by a
     * semicolon.
     *
     * @param id     statement id
     * @param bytes  serialized query parameters
     * @param sql    the query
     * @param ignore names of tables that must not be cached; may be null (nothing ignored)
     * @return the key, or null when the query reads an ignored table and must not be cached
     * @throws JSQLParserException when the query cannot be parsed
     */
    public String getCacheKey(String id, byte[] bytes, String sql, Set<String> ignore) throws JSQLParserException {
        StringBuilder key = new StringBuilder(SecureUtil.md5(id + SEMICOLON + Convert.toStr(bytes)));
        List<String> tableNameList = this.getTableNameList(sql);
        key.append(SEMICOLON);
        for (String tableName : tableNameList) {
            String lowerCaseName = tableName.toLowerCase();
            // Keys use lower-cased names, so an ignore entry given in lower case matches as well.
            if (ignore != null && (ignore.contains(tableName) || ignore.contains(lowerCaseName))) {
                return null;
            }
            key.append(lowerCaseName).append(SEMICOLON);
        }
        return key.toString();
    }
}
- 实现缓存读写插件CacheReadWritePlugin。在查询sql的时候,去判断是否存在缓存,如果存在缓存,则获取缓存值,否则正常查询,然后将缓存写入到connection缓存中。在更新sql的时候,去进行更新缓存。
注: 缓存的key是通过序列化参数值来判断是否为相同的查询,所以所有影响查询结果的参数都必须为可序列化。
import cn.hutool.core.collection.ConcurrentHashSet;
import cn.hutool.core.util.ObjectUtil;
import cn.hutool.core.util.StrUtil;
import cn.hutool.log.LogFactory;
import org.apache.ibatis.executor.statement.StatementHandler;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.plugin.Interceptor;
import org.apache.ibatis.plugin.Intercepts;
import org.apache.ibatis.plugin.Invocation;
import org.apache.ibatis.plugin.Signature;
import org.apache.ibatis.reflection.MetaObject;
import org.apache.ibatis.reflection.SystemMetaObject;
import org.apache.ibatis.session.ResultHandler;
import java.lang.reflect.InvocationTargetException;
import java.sql.Connection;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.Arrays;
import java.util.Properties;
import java.util.Set;
/**
 * MyBatis plugin that reads and writes the query cache.
 *
 * <p>Intercepts {@code StatementHandler.query}: a cached result is returned when one exists,
 * otherwise the query runs and its result is stored in the connection's cache. Intercepts
 * {@code StatementHandler.update}: cache entries depending on the updated tables are flushed
 * before the statement runs.
 *
 * <p>Cache failures never affect statement execution: every cache operation is best-effort,
 * and the intercepted statement is executed exactly once.
 *
 * @author gaolj
 */
@Intercepts({
        @Signature(
                type = StatementHandler.class,
                method = "query",
                args = {Statement.class, ResultHandler.class}
        ),
        @Signature(
                type = StatementHandler.class,
                method = "update",
                args = {Statement.class}
        )
})
public class CacheReadWritePlugin extends AbstractCachePlugin implements Interceptor {
    /**
     * Name of the intercepted query method.
     */
    private final static String METHOD_NAME_QUERY = "query";
    /**
     * Name of the intercepted update method.
     */
    private final static String METHOD_NAME_UPDATE = "update";
    /**
     * Names of tables whose queries must not be cached.
     */
    private final static Set<String> IGNORE_SET = new ConcurrentHashSet<>(8);

    /**
     * Dispatches to the query or update handling according to the intercepted method.
     *
     * @param invocation the intercepted invocation
     * @return the query result (possibly from cache) or the update result
     * @throws Throwable anything thrown by the intercepted statement
     */
    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        String invocationMethodName = invocation.getMethod().getName();
        if (METHOD_NAME_QUERY.equals(invocationMethodName)) {
            return executeQuery(invocation);
        }
        if (METHOD_NAME_UPDATE.equals(invocationMethodName)) {
            return executeUpdate(invocation);
        }
        return invocation.proceed();
    }

    /**
     * Serves a query from cache when possible, otherwise runs it and caches the result.
     *
     * <p>Only the cache work is wrapped in try/catch. A failure of the statement itself
     * propagates to the caller and is never retried, so the statement cannot run twice.
     *
     * @param invocation the intercepted {@code query} invocation
     * @return the query result
     * @throws InvocationTargetException when the intercepted statement throws
     * @throws IllegalAccessException    when the intercepted method cannot be invoked
     */
    private Object executeQuery(Invocation invocation) throws InvocationTargetException, IllegalAccessException {
        Connection connection = null;
        String key = null;
        try {
            Statement statement = (Statement) invocation.getArgs()[0];
            connection = statement.getConnection();
            StatementHandler stmtHandler = (StatementHandler) getUnProxyObject(invocation.getTarget());
            MetaObject metaStatementHandler = SystemMetaObject.forObject(stmtHandler);
            MappedStatement mappedStatement = (MappedStatement) metaStatementHandler.getValue("delegate.mappedStatement");
            BoundSql boundSql = (BoundSql) metaStatementHandler.getValue("delegate.boundSql");
            Object parameterObject = boundSql.getParameterObject();
            byte[] bytes = ObjectUtil.serialize(parameterObject);
            // hutool returns null instead of throwing for a non-Serializable object. Such
            // parameters cannot be told apart, so the query is not cached at all rather than
            // letting different parameter values share one key.
            boolean parametersIdentifiable = parameterObject == null || bytes != null;
            if (parametersIdentifiable) {
                // A null key means the query reads an ignored table and must not be cached.
                key = this.getCacheKey(mappedStatement.getId(), bytes, boundSql.getSql(), IGNORE_SET);
            }
            if (key != null && this.cacheContainsKey(connection, key)) {
                Object cached = this.getCacheValue(connection, key);
                // The entry may have been evicted between the two calls; then fall through.
                if (cached != null) {
                    return cached;
                }
            }
        } catch (Exception e) {
            LogFactory.get().warn("cache get or update fail:\r\n" + e.getMessage());
            // The cache is unusable for this call: neither read nor write it.
            key = null;
        }
        Object result = invocation.proceed();
        if (key != null) {
            try {
                this.setCacheValue(connection, key, result);
            } catch (Exception e) {
                LogFactory.get().warn("cache get or update fail:\r\n" + e.getMessage());
            }
        }
        return result;
    }

    /**
     * Flushes the cache entries affected by an update, then runs the update.
     *
     * <p>A failure while flushing is logged and does not prevent the update; a failure of the
     * update itself propagates to the caller and is never retried.
     *
     * @param invocation the intercepted {@code update} invocation
     * @return the update result
     * @throws InvocationTargetException when the intercepted statement throws
     * @throws IllegalAccessException    when the intercepted method cannot be invoked
     */
    private Object executeUpdate(Invocation invocation) throws InvocationTargetException, IllegalAccessException {
        try {
            Statement statement = (Statement) invocation.getArgs()[0];
            Connection connection = statement.getConnection();
            StatementHandler stmtHandler = (StatementHandler) getUnProxyObject(invocation.getTarget());
            MetaObject metaStatementHandler = SystemMetaObject.forObject(stmtHandler);
            BoundSql boundSql = (BoundSql) metaStatementHandler.getValue("delegate.boundSql");
            this.flushCache(connection, boundSql.getSql());
        } catch (Exception e) {
            LogFactory.get().warn("cache get or update fail:\r\n" + e.getMessage());
        }
        return invocation.proceed();
    }

    /**
     * Reads the {@code ignore} property: a semicolon-separated list of table names whose
     * queries must not be cached. Surrounding whitespace is trimmed and empty entries skipped.
     *
     * @param properties plugin properties from the MyBatis configuration; may be null
     */
    @Override
    public void setProperties(Properties properties) {
        if (properties == null) {
            return;
        }
        String ignore = properties.getProperty("ignore");
        if (StrUtil.isNotEmpty(ignore)) {
            for (String tableName : Arrays.asList(ignore.split(SEMICOLON))) {
                String trimmed = tableName.trim();
                if (!trimmed.isEmpty()) {
                    IGNORE_SET.add(trimmed);
                }
            }
        }
    }
}
- 缓存提交回滚插件CacheCommitRollbackPlugin实现。在事务提交的时候,进行提交缓存,在事务回滚的时候,清除缓存。当连接关闭的时候,需要判断是否回滚。
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.plugin.Interceptor;
import org.apache.ibatis.plugin.Intercepts;
import org.apache.ibatis.plugin.Invocation;
import org.apache.ibatis.plugin.Signature;
import java.sql.Connection;
/**
 * MyBatis plugin that commits or rolls back the connection cache together with the transaction.
 *
 * <p>Intercepts {@code Executor.commit}, {@code Executor.rollback} and {@code Executor.close}:
 * <ul>
 *   <li>commit, or close without rollback: the connection cache is published to the global
 *       cache, but only after the intercepted call has succeeded;</li>
 *   <li>rollback, or close with rollback: the connection cache is discarded.</li>
 * </ul>
 *
 * @author gaolj
 */
@Intercepts({
        @Signature(
                type = Executor.class,
                method = "commit",
                args = {boolean.class}
        ),
        @Signature(
                type = Executor.class,
                method = "rollback",
                args = {boolean.class}
        ),
        @Signature(
                type = Executor.class,
                method = "close",
                args = {boolean.class}
        )
})
public class CacheCommitRollbackPlugin extends AbstractCachePlugin implements Interceptor {
    /**
     * Name of the intercepted commit method.
     */
    private final static String METHOD_NAME_COMMIT = "commit";
    /**
     * Name of the intercepted rollback method.
     */
    private final static String METHOD_NAME_ROLLBACK = "rollback";
    /**
     * Name of the intercepted close method.
     */
    private final static String METHOD_NAME_CLOSE = "close";

    /**
     * Keeps the connection cache in step with the transaction outcome.
     *
     * @param invocation the intercepted invocation
     * @return whatever the intercepted method returns
     * @throws Throwable anything thrown by the intercepted method
     */
    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        Executor executor = (Executor) getUnProxyObject(invocation.getTarget());
        // Resolved before proceeding: after close() the executor no longer exposes it.
        Connection connection = executor.getTransaction().getConnection();
        String invocationMethodName = invocation.getMethod().getName();

        boolean publish = false;
        boolean discard = false;
        if (METHOD_NAME_COMMIT.equals(invocationMethodName)) {
            publish = true;
        } else if (METHOD_NAME_ROLLBACK.equals(invocationMethodName)) {
            discard = true;
        } else if (METHOD_NAME_CLOSE.equals(invocationMethodName)) {
            // The single argument of close() tells whether a rollback is forced.
            boolean isRollback = (boolean) invocation.getArgs()[0];
            discard = isRollback;
            publish = !isRollback;
        }

        if (discard) {
            this.rollbackCache(connection);
            return invocation.proceed();
        }
        if (!publish) {
            return invocation.proceed();
        }

        Object result;
        try {
            result = invocation.proceed();
        } catch (Throwable t) {
            // The transaction did not complete, so results read inside it must never
            // reach the global cache.
            this.rollbackCache(connection);
            throw t;
        }
        // Publish only once the intercepted commit/close has actually succeeded.
        this.commitCache(connection);
        return result;
    }
}
- 将插件CacheReadWritePlugin、CacheCommitRollbackPlugin配置到mybatis插件中去,即可使用。其中可在CacheReadWritePlugin插件下配置参数"ignore"来过滤掉不需要缓存的表,多张表使用";"分割,例如:
<plugin interceptor="package.plugin.cache.CacheReadWritePlugin">
<property name="ignore" value="table1;table2"/>
</plugin>
注: 需要将mybatis二级缓存关闭,要不无法正常使用。