mybatis二级缓存实现

mybatis二级缓存实现

简介

由于MyBatis自带的二级缓存以namespace为单位进行隔离,关联查询涉及的表在其他namespace中被更新时,缓存无法及时失效,因此不宜直接使用。本文通过MyBatis拦截器自行实现二级缓存,其中缓存存储使用hutool提供的缓存实现,也可替换为自定义实现。本文仅提供思路,如有问题欢迎指出。

实现原理

  1. 通过拦截StatementHandler接口下的query以及update方法来实现缓存的读写及更新操作;
  2. 通过拦截Executor接口下的commit、rollback以及close方法来实现缓存的提交及回滚操作;

具体实现

  1. 定义一个mybatis插件抽象类AbstractPlugin
import org.apache.ibatis.plugin.Interceptor;
import org.apache.ibatis.plugin.Invocation;
import org.apache.ibatis.plugin.Plugin;
import org.apache.ibatis.reflection.MetaObject;
import org.apache.ibatis.reflection.SystemMetaObject;

import java.beans.PropertyDescriptor;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.util.Map;
import java.util.Properties;

/**
 * Base class for MyBatis plugins.
 * <p>
 * Supplies a pass-through {@link #intercept(Invocation)}, wraps targets via {@link Plugin#wrap},
 * ignores configuration properties, and offers {@link #getUnProxyObject(Object)} to reach the
 * real object hidden behind a chain of plugin proxies.
 * @author gaolj
 */
public abstract class AbstractPlugin implements Interceptor {

    /**
     * Default interception: simply invoke the intercepted method.
     * @param invocation the intercepted call
     * @return the result of the intercepted method
     * @throws Throwable whatever the intercepted method throws
     */
    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        return invocation.proceed();
    }

    /**
     * Wraps the target with this interceptor.
     * @param target object to wrap
     * @return the object returned by {@link Plugin#wrap(Object, Interceptor)}
     */
    @Override
    public Object plugin(Object target) {
        // Delegate proxy creation to MyBatis.
        return Plugin.wrap(target, this);
    }

    /**
     * No configuration properties by default.
     * @param properties plugin properties from the MyBatis configuration
     */
    @Override
    public void setProperties(Properties properties) {
    }

    /**
     * Strips every MyBatis plugin proxy layer from {@code target} and returns the innermost object.
     * <p>
     * Each plugin layer is a JDK dynamic proxy whose invocation handler is a {@link Plugin} that
     * keeps the next (possibly itself proxied) object in its {@code target} field. The loop unwraps
     * layer by layer, so chains of any length are handled. The handler is obtained through
     * {@link java.lang.reflect.Proxy#getInvocationHandler(Object)}, not by reflectively reading the
     * proxy's {@code h} field: that field lives in {@code java.base}, and deep reflection into it is
     * refused on JDK 16+ unless the module is explicitly opened.
     * @param target possibly proxied object, e.g. {@code Invocation#getTarget()}; may be {@code null}
     * @return the innermost non-plugin object, or {@code target} itself when it is not a plugin proxy
     */
    protected Object getUnProxyObject(Object target) {
        Object current = target;
        while (current != null && java.lang.reflect.Proxy.isProxyClass(current.getClass())) {
            java.lang.reflect.InvocationHandler handler = java.lang.reflect.Proxy.getInvocationHandler(current);
            if (!(handler instanceof Plugin)) {
                // A foreign proxy (not created by MyBatis plugins): nothing more we can unwrap.
                break;
            }
            current = SystemMetaObject.forObject(handler).getValue("target");
        }
        return current;
    }

}
  2. 定义缓存抽象类AbstractCachePlugin,在该类中实现缓存的存取等方法。缓存分为全局缓存和connection缓存:connection提交时,将其connection缓存提交到全局缓存;回滚时,清空该connection缓存。通过net.sf.jsqlparser解析sql语句,获取缓存失效的依据(即sql涉及的表)。

import cn.hutool.cache.Cache;
import cn.hutool.cache.CacheUtil;
import cn.hutool.cache.impl.CacheObj;
import cn.hutool.core.convert.Convert;
import cn.hutool.crypto.SecureUtil;
import cn.hutool.log.LogFactory;
import com.ctm.repair.persistent.plugin.AbstractPlugin;
import net.sf.jsqlparser.JSQLParserException;
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
import net.sf.jsqlparser.util.TablesNamesFinder;

import java.sql.Connection;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.locks.ReentrantReadWriteLock;

/**
 * Abstract base class of the cache plugins.
 * <p>
 * There are two cache tiers:
 * <ul>
 *   <li>a global LRU cache shared by every connection, holding results published by committed
 *   transactions;</li>
 *   <li>a per-connection cache that collects query results while a transaction is open. It is
 *   published to the global cache by {@link #commitCache(Connection)} and discarded by
 *   {@link #rollbackCache(Connection)}.</li>
 * </ul>
 * Every cache key ends with the normalized names of the tables its SQL touches (extracted with
 * JSqlParser), e.g. {@code <md5>;user;order;}, so a write to a table can evict every dependent entry.
 * <p>
 * Staleness protection: each table invalidation bumps a global sequence number and records it for
 * the table. Connection-cache entries remember the sequence number current when they were stored.
 * On commit, an entry is published only if none of its tables was invalidated after it was stored.
 * The tables written by the committing transaction are invalidated once more at commit time.
 * Together these stop rows read before another transaction's write+commit from being published.
 * <p>
 * All cache-state access is guarded by a single {@code static} read/write lock. The read/write
 * plugin and the commit/rollback plugin are separate interceptor instances working on the same
 * static caches, so a per-instance lock would not make them mutually exclusive.
 * @author gaolj
 */
public abstract class AbstractCachePlugin extends AbstractPlugin {

    /**
     * Separator used inside cache keys (after the digest and after each table name).
     */
    public final static String SEMICOLON = ";";

    /**
     * Global cache: results published by committed transactions.
     */
    private final static Cache<String, Object> GLOBAL_CACHE = CacheUtil.newLRUCache(4096);

    /**
     * Per-connection transaction state (private cache, store sequence numbers, written tables).
     */
    private final static Map<Connection, ConnectionState> CONNECTION_STATE_MAP = new ConcurrentHashMap<>();

    /**
     * Sequence number of the latest invalidation of each normalized table name. Guarded by WRITE_LOCK.
     */
    private final static Map<String, Long> TABLE_FLUSH_SEQUENCE = new HashMap<>();

    /**
     * Monotonic invalidation counter. Guarded by WRITE_LOCK.
     */
    private static long flushSequence = 0L;

    /**
     * Sequence number of the latest full (all tables) invalidation. Guarded by WRITE_LOCK.
     */
    private static long fullFlushSequence = 0L;

    // static on purpose: shared by all plugin instances, like the caches it protects
    private final static ReentrantReadWriteLock CACHE_LOCK = new ReentrantReadWriteLock();
    private final static ReentrantReadWriteLock.ReadLock READ_LOCK = CACHE_LOCK.readLock();
    private final static ReentrantReadWriteLock.WriteLock WRITE_LOCK = CACHE_LOCK.writeLock();

    /**
     * Returns the cache a connection reads from.
     * A connection that owns a connection state reads only its own cache; otherwise it reads the
     * global cache.
     * @param connection connection
     * @return cache object
     */
    private Cache<String, Object> getCache(Connection connection) {
        ConnectionState state = CONNECTION_STATE_MAP.get(connection);
        return state != null ? state.cache : GLOBAL_CACHE;
    }

    /**
     * Checks whether the cache visible to {@code connection} contains {@code key}.
     * @param connection connection
     * @param key key
     * @return true if a cached value exists; false otherwise
     */
    public boolean cacheContainsKey(Connection connection, String key) {
        READ_LOCK.lock();
        try {
            return this.getCache(connection).containsKey(key);
        } finally {
            READ_LOCK.unlock();
        }
    }

    /**
     * Reads a value from the cache visible to {@code connection}.
     * @param connection connection
     * @param key key
     * @return cached value, or {@code null} on a miss
     */
    public Object getCacheValue(Connection connection, String key) {
        READ_LOCK.lock();
        try {
            Object value = this.getCache(connection).get(key);
            if (value != null) {
                LogFactory.get().debug("get results from cache,cache key: " + key);
            }
            return value;
        } finally {
            READ_LOCK.unlock();
        }
    }

    /**
     * Stores a value in the connection cache, creating the connection state if needed.
     * <p>
     * NOTE(review): the stored sequence number is the one current when this method runs, not when
     * the query that produced {@code value} started. An invalidation landing between the two is
     * therefore not detected.
     * @param connection connection
     * @param key cache key
     * @param value value to cache
     */
    public void setCacheValue(Connection connection, String key, Object value) {
        WRITE_LOCK.lock();
        try {
            ConnectionState state = getOrCreateState(connection);
            state.cache.put(key, value);
            state.cachedAt.put(key, flushSequence);
            LogFactory.get().debug("write connection cache,cache key: " + key);
        } finally {
            WRITE_LOCK.unlock();
        }
    }

    /**
     * Invalidates every cache entry depending on the tables written by {@code sql}.
     * <p>
     * Matching entries are removed from this connection's cache and from the global cache. The
     * tables are remembered so they can be invalidated again at commit.
     * <p>
     * If the tables cannot be determined (e.g. the SQL cannot be parsed), the whole global cache
     * is flushed, this connection's cache is cleared, and a full flush is scheduled for its commit.
     * @param connection connection executing the write
     * @param sql sql statement
     */
    public void flushCache(Connection connection, String sql) {
        WRITE_LOCK.lock();
        try {
            List<String> tableNameList = this.getTableNameList(sql);
            if (tableNameList.isEmpty()) {
                return;
            }
            // From now on the connection can see its own uncommitted writes. It must therefore stop
            // reading the shared global cache; owning a connection state routes getCache() there.
            ConnectionState state = getOrCreateState(connection);
            for (String tableName : tableNameList) {
                String normalizedName = normalizeTableName(tableName);
                state.modifiedTables.add(normalizedName);
                removeEntriesForTable(state.cache, normalizedName, "connection");
                invalidateTable(normalizedName);
            }
        } catch (Exception e) {
            LogFactory.get().warn("flush cache fail:\r\n" + e.getMessage());
            // The affected tables are unknown, so fall back to invalidating everything.
            ConnectionState state = getOrCreateState(connection);
            state.cache.clear();
            state.cachedAt.clear();
            state.flushAllOnCommit = true;
            invalidateAll();
            LogFactory.get().debug("flush all cache!");
        } finally {
            WRITE_LOCK.unlock();
        }
    }

    /**
     * Extracts the table names referenced by a SQL statement.
     * @param sql sql
     * @return table names as written in the SQL (never {@code null})
     * @throws JSQLParserException if the SQL cannot be parsed
     */
    public List<String> getTableNameList(String sql) throws JSQLParserException {
        net.sf.jsqlparser.statement.Statement statement = CCJSqlParserUtil.parse(sql);
        List<String> tableNameList = new TablesNamesFinder().getTableList(statement);
        // Immutable empty list: callers must not be able to mutate a shared instance.
        return tableNameList == null ? Collections.<String>emptyList() : tableNameList;
    }

    /**
     * Rolls back the cache: discards everything the connection cached.
     * @param connection connection
     */
    public void rollbackCache(Connection connection) {
        WRITE_LOCK.lock();
        try {
            if (CONNECTION_STATE_MAP.remove(connection) != null) {
                LogFactory.get().debug("connection cache remove!");
            }
        } finally {
            WRITE_LOCK.unlock();
        }
    }

    /**
     * Commits the cache: publishes the connection's still-valid entries to the global cache, then
     * drops the connection state.
     * <p>
     * First, the tables this transaction wrote are invalidated again: other connections may have
     * cached their pre-commit rows since the write. Then an entry is published only if none of its
     * tables was invalidated after the entry was stored.
     * @param connection connection
     */
    public void commitCache(Connection connection) {
        WRITE_LOCK.lock();
        try {
            ConnectionState state = CONNECTION_STATE_MAP.remove(connection);
            if (state == null) {
                return;
            }
            if (state.flushAllOnCommit) {
                invalidateAll();
            } else {
                for (String tableName : state.modifiedTables) {
                    invalidateTable(tableName);
                }
            }
            Iterator<CacheObj<String, Object>> iterator = state.cache.cacheObjIterator();
            while (iterator.hasNext()) {
                CacheObj<String, Object> cacheObj = iterator.next();
                String key = cacheObj.getKey();
                Long cachedAt = state.cachedAt.get(key);
                if (cachedAt == null || isStale(key, cachedAt)) {
                    LogFactory.get().debug("skip stale connection cache on commit, key: " + key);
                    continue;
                }
                GLOBAL_CACHE.put(key, cacheObj.getValue());
                LogFactory.get().debug("global cache write,write key: " + key);
            }
            LogFactory.get().debug("connection cache remove!");
        } finally {
            WRITE_LOCK.unlock();
        }
    }

    /**
     * Builds the cache key for a query.
     * The format is {@code md5(id;sql;params);table1;table2;...;}, with normalized table names.
     * The SQL text is part of the digest, so statements whose SQL was rewritten (e.g. by another
     * plugin) do not share keys.
     * @param id statement id
     * @param bytes serialized parameter object (may be {@code null} when there is no parameter)
     * @param sql sql statement
     * @param ignore table names whose queries must not be cached (raw or normalized form); may be {@code null}
     * @return key, or {@code null} when the query touches an ignored table
     * @throws JSQLParserException if the SQL cannot be parsed
     */
    public String getCacheKey(String id, byte[] bytes, String sql, Set<String> ignore) throws JSQLParserException {
        List<String> tableNameList = this.getTableNameList(sql);
        StringBuilder key = new StringBuilder(
                SecureUtil.md5(id + SEMICOLON + sql + SEMICOLON + Convert.toStr(bytes)));
        key.append(SEMICOLON);
        for (String tableName : tableNameList) {
            String normalizedName = normalizeTableName(tableName);
            if (ignore != null && (ignore.contains(tableName) || ignore.contains(normalizedName))) {
                return null;
            }
            key.append(normalizedName).append(SEMICOLON);
        }
        return key.toString();
    }

    /**
     * Returns the connection's state, creating an empty one if absent. Caller holds WRITE_LOCK.
     */
    private static ConnectionState getOrCreateState(Connection connection) {
        return CONNECTION_STATE_MAP.computeIfAbsent(connection, c -> new ConnectionState());
    }

    /**
     * Canonical table name used in keys: no schema prefix, no identifier quotes, lower case (ROOT locale).
     * Dropping the schema may over-invalidate same-named tables in different schemas, which is safe.
     */
    private static String normalizeTableName(String tableName) {
        String name = tableName.trim();
        int dot = name.lastIndexOf('.');
        if (dot >= 0) {
            name = name.substring(dot + 1);
        }
        if (name.length() >= 2) {
            char first = name.charAt(0);
            char last = name.charAt(name.length() - 1);
            if ((first == '`' && last == '`') || (first == '"' && last == '"') || (first == '[' && last == ']')) {
                name = name.substring(1, name.length() - 1);
            }
        }
        return name.toLowerCase(Locale.ROOT);
    }

    /**
     * Removes from {@code cache} every key that references {@code tableName}; keys are collected first
     * so the cache is never modified while it is being iterated.
     */
    private static void removeEntriesForTable(Cache<String, Object> cache, String tableName, String scope) {
        String tableNameKey = SEMICOLON + tableName + SEMICOLON;
        List<String> keys = new ArrayList<>();
        Iterator<CacheObj<String, Object>> iterator = cache.cacheObjIterator();
        while (iterator.hasNext()) {
            String key = iterator.next().getKey();
            if (key.contains(tableNameKey)) {
                keys.add(key);
            }
        }
        for (String key : keys) {
            cache.remove(key);
            LogFactory.get().debug("remove " + scope + " cache when update table, table name: " + tableName + ", key: " + key);
        }
    }

    /**
     * Records a new invalidation of {@code tableName} and evicts its global entries. Caller holds WRITE_LOCK.
     */
    private static void invalidateTable(String tableName) {
        flushSequence++;
        TABLE_FLUSH_SEQUENCE.put(tableName, flushSequence);
        removeEntriesForTable(GLOBAL_CACHE, tableName, "global");
    }

    /**
     * Records a full invalidation and clears the global cache. Caller holds WRITE_LOCK.
     */
    private static void invalidateAll() {
        flushSequence++;
        fullFlushSequence = flushSequence;
        GLOBAL_CACHE.clear();
    }

    /**
     * True if any table named in {@code key}, or everything, was invalidated after {@code cachedAt}.
     * Caller holds WRITE_LOCK.
     */
    private static boolean isStale(String key, long cachedAt) {
        if (fullFlushSequence > cachedAt) {
            return true;
        }
        // parts[0] is the digest; the remaining parts are normalized table names.
        String[] parts = key.split(SEMICOLON);
        for (int i = 1; i < parts.length; i++) {
            Long flushedAt = TABLE_FLUSH_SEQUENCE.get(parts[i]);
            if (flushedAt != null && flushedAt > cachedAt) {
                return true;
            }
        }
        return false;
    }

    /**
     * Cache state of one connection between transaction boundaries. Accessed under the cache lock.
     */
    private static final class ConnectionState {
        /** Query results cached by this connection, not yet visible to others. */
        private final Cache<String, Object> cache = CacheUtil.newLRUCache(1024);
        /** Value of flushSequence when each key was stored. */
        private final Map<String, Long> cachedAt = new HashMap<>();
        /** Normalized names of the tables this connection wrote. */
        private final Set<String> modifiedTables = new HashSet<>();
        /** Set when a write's tables could not be determined. */
        private boolean flushAllOnCommit;
    }

}
  3. 实现缓存读写插件CacheReadWritePlugin:执行查询sql时,先判断是否存在缓存,存在则直接返回缓存值,否则正常查询,并将结果写入connection缓存;执行更新sql时,刷新相关缓存。
    注: 缓存key由序列化后的参数值生成,用于判断是否为相同的查询,因此所有影响查询结果的参数都必须可序列化。

import cn.hutool.core.collection.ConcurrentHashSet;
import cn.hutool.core.util.ObjectUtil;
import cn.hutool.core.util.StrUtil;
import cn.hutool.log.LogFactory;
import org.apache.ibatis.executor.statement.StatementHandler;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.plugin.Interceptor;
import org.apache.ibatis.plugin.Intercepts;
import org.apache.ibatis.plugin.Invocation;
import org.apache.ibatis.plugin.Signature;
import org.apache.ibatis.reflection.MetaObject;
import org.apache.ibatis.reflection.SystemMetaObject;
import org.apache.ibatis.session.ResultHandler;

import java.lang.reflect.InvocationTargetException;
import java.sql.Connection;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.Arrays;
import java.util.Properties;
import java.util.Set;

/**
 * MyBatis cache read/write plugin.
 * <p>
 * {@code StatementHandler.query}: the result is looked up in the cache. On a miss the query runs
 * and its result is stored in the connection-level cache.
 * <p>
 * {@code StatementHandler.update} and {@code StatementHandler.batch}: the cache entries of every
 * table the statement touches are flushed before it runs. MyBatis's BatchExecutor adds statements
 * through {@code batch()} rather than {@code update()}, so both must be intercepted.
 * <p>
 * Cache failures never fail the statement: they are logged and the statement runs uncached. The
 * intercepted call itself is executed at most once.
 * <p>
 * Note: the cache key is derived from the Java-serialized parameter object, so every parameter that
 * affects the result must be serializable. Queries whose parameter cannot be serialized are not cached.
 * @author gaolj
 */
@Intercepts({
		@Signature(
				type = StatementHandler.class,
				method = "query",
				args = {Statement.class, ResultHandler.class}
		),
		@Signature(
				type = StatementHandler.class,
				method = "update",
				args = {Statement.class}
		),
		@Signature(
				type = StatementHandler.class,
				method = "batch",
				args = {Statement.class}
		)
})
public class CacheReadWritePlugin extends AbstractCachePlugin implements Interceptor {

	/**
	 * Query method name.
	 */
	private final static String METHOD_NAME_QUERY = "query";

	/**
	 * Update method name.
	 */
	private final static String METHOD_NAME_UPDATE = "update";

	/**
	 * Batch method name (used instead of update by the batch executor).
	 */
	private final static String METHOD_NAME_BATCH = "batch";

	/**
	 * Table names whose queries are never cached (from the "ignore" property).
	 */
	private static final Set<String> IGNORE_SET = new ConcurrentHashSet<>(8);

	@Override
	public Object intercept(Invocation invocation) throws Throwable {
		String invocationMethodName = invocation.getMethod().getName();
		if (METHOD_NAME_QUERY.equals(invocationMethodName)) {
			return executeQuery(invocation);
		}
		if (METHOD_NAME_UPDATE.equals(invocationMethodName) || METHOD_NAME_BATCH.equals(invocationMethodName)) {
			return executeUpdate(invocation);
		}
		return invocation.proceed();
	}

	/**
	 * Executes a query through the cache.
	 * Only cache bookkeeping is guarded by try/catch; {@code invocation.proceed()} is called exactly
	 * once and its exceptions propagate unchanged.
	 * <p>
	 * NOTE(review): the cached result object is handed to every later caller as-is. If callers
	 * mutate returned lists/entities, the cached value changes too — consider caching a copy.
	 * @param invocation invocation
	 * @return the query result
	 */
	private Object executeQuery(Invocation invocation) throws Exception {
		Connection connection = null;
		String key = null;
		try {
			// With a caller-supplied ResultHandler MyBatis hands the rows to that handler and returns
			// an empty list, so such a query can be neither served from nor stored in the cache.
			if (invocation.getArgs()[1] == null) {
				connection = ((Statement) invocation.getArgs()[0]).getConnection();
				key = this.buildCacheKey(invocation);
			}
		} catch (Exception e) {
			warnCacheFailure(e);
			key = null;
		}
		if (key == null) {
			return invocation.proceed();
		}
		// Single lookup: a separate containsKey/get pair could see the entry evicted in between
		// and return null as the query result.
		Object cached = null;
		try {
			cached = this.getCacheValue(connection, key);
		} catch (Exception e) {
			warnCacheFailure(e);
		}
		if (cached != null) {
			return cached;
		}
		Object result = invocation.proceed();
		try {
			this.setCacheValue(connection, key, result);
		} catch (Exception e) {
			warnCacheFailure(e);
		}
		return result;
	}

	/**
	 * Builds the cache key for the intercepted query.
	 * @param invocation invocation
	 * @return key, or {@code null} when the query must not be cached
	 */
	private String buildCacheKey(Invocation invocation) throws Exception {
		StatementHandler stmtHandler = (StatementHandler) getUnProxyObject(invocation.getTarget());
		MetaObject metaStatementHandler = SystemMetaObject.forObject(stmtHandler);
		MappedStatement mappedStatement = (MappedStatement) metaStatementHandler.getValue("delegate.mappedStatement");
		BoundSql boundSql = (BoundSql) metaStatementHandler.getValue("delegate.boundSql");
		Object parameterObject = boundSql.getParameterObject();
		byte[] bytes = ObjectUtil.serialize(parameterObject);
		if (parameterObject != null && bytes == null) {
			// A null serialization of a non-null parameter would make every call share one key.
			return null;
		}
		// RowBounds paging is applied while reading the ResultSet, not in the SQL, so it must be part of the key.
		String id = mappedStatement.getId()
				+ SEMICOLON + metaStatementHandler.getValue("delegate.rowBounds.offset")
				+ SEMICOLON + metaStatementHandler.getValue("delegate.rowBounds.limit");
		return this.getCacheKey(id, bytes, boundSql.getSql(), IGNORE_SET);
	}

	/**
	 * Flushes the cache for the tables written by the intercepted statement, then executes it once.
	 * @param invocation invocation
	 * @return Object
	 */
	private Object executeUpdate(Invocation invocation) throws Exception {
		try {
			Connection connection = ((Statement) invocation.getArgs()[0]).getConnection();
			StatementHandler stmtHandler = (StatementHandler) getUnProxyObject(invocation.getTarget());
			MetaObject metaStatementHandler = SystemMetaObject.forObject(stmtHandler);
			BoundSql boundSql = (BoundSql) metaStatementHandler.getValue("delegate.boundSql");
			this.flushCache(connection, boundSql.getSql());
		} catch (Exception e) {
			warnCacheFailure(e);
		}
		return invocation.proceed();
	}

	/**
	 * Logs a cache failure without interrupting the statement.
	 */
	private static void warnCacheFailure(Exception e) {
		LogFactory.get().warn("cache get or update fail:\r\n" + e.getMessage());
	}

	/**
	 * Reads the optional "ignore" property: table names separated by ';'. Whitespace around each
	 * name is ignored, and empty entries are skipped.
	 * @param properties plugin properties
	 */
	@Override
	public void setProperties(Properties properties) {
		String ignore = properties == null ? null : properties.getProperty("ignore");
		if (StrUtil.isNotEmpty(ignore)) {
			for (String tableName : ignore.split(SEMICOLON)) {
				String trimmed = tableName.trim();
				if (!trimmed.isEmpty()) {
					IGNORE_SET.add(trimmed);
				}
			}
		}
	}

}
  4. 实现缓存提交回滚插件CacheCommitRollbackPlugin:事务提交时提交缓存,事务回滚时清除缓存;连接关闭时,根据是否需要回滚来决定提交还是清除缓存。

import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.plugin.Interceptor;
import org.apache.ibatis.plugin.Intercepts;
import org.apache.ibatis.plugin.Invocation;
import org.apache.ibatis.plugin.Signature;

import java.sql.Connection;

/**
 * mybatis缓存提交回滚插件
 * @author gaolj
 */
@Intercepts({
		@Signature(
				type = Executor.class,
				method = "commit",
				args = {boolean.class}
		),
		@Signature(
				type = Executor.class,
				method = "rollback",
				args = {boolean.class}
		),
		@Signature(
				type = Executor.class,
				method = "close",
				args = {boolean.class}
		)
})
public class CacheCommitRollbackPlugin extends AbstractCachePlugin implements Interceptor {

	/**
	 * 提交方法名称
	 */
	private final static String METHOD_NAME_COMMIT = "commit";

	/**
	 * 回滚方法名称
	 */
	private final static String METHOD_NAME_ROLLBACK = "rollback";

	/**
	 * 关闭方法名称
	 */
	private final static String METHOD_NAME_CLOSE = "close";

	@Override
	public Object intercept(Invocation invocation) throws Throwable {
		Executor executor = (Executor) getUnProxyObject(invocation.getTarget());
		Connection connection = executor.getTransaction().getConnection();
		String invocationMethodName = invocation.getMethod().getName();
		if (METHOD_NAME_COMMIT.equals(invocationMethodName)) {
			this.commitCache(connection);
		} else if (METHOD_NAME_ROLLBACK.equals(invocationMethodName)) {
			this.rollbackCache(connection);
		} else if (METHOD_NAME_CLOSE.equals(invocationMethodName)) {
			boolean isRollback = (boolean) invocation.getArgs()[0];
			if (isRollback) {
				this.rollbackCache(connection);
			} else {
				this.commitCache(connection);
			}
		}
		return invocation.proceed();
	}

}
  5. 将插件CacheReadWritePlugin和CacheCommitRollbackPlugin配置到mybatis插件中,即可使用。可在CacheReadWritePlugin插件下配置参数"ignore"来排除不需要缓存的表,多张表使用";"分隔,例如:
		<plugin interceptor="package.plugin.cache.CacheReadWritePlugin">
			<property name="ignore" value="table1;table2"/>
		</plugin>

注: 需要关闭mybatis自带的二级缓存(例如在配置文件中设置 <setting name="cacheEnabled" value="false"/>),否则无法正常使用。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值