问题描述:
项目开发中由于数据量非常大(和某些表设计的不太合理),导致权限比较高的账号进行查询数据的时候SQL语句 in (....)的字段(id/relatedid)比较多,导致SQL查询效率比较慢,耗时较高
原因分析:
由于业务需要,只能通过优化mysql查询的方式来处理。1,修改mysql源码(暂不支持,人力物力对于小公司消耗较大)2,通过在查询语句开始前新建临时表的方式来优化in查询语句
select * from table where relatedID in (id1,id2...id100000) 优化成 select * from table where relatedID in (select id from relatedIds);relatedis是系统新建的临时表
解决思路:
1,由于需要创建临时表,则要用到mybatis的Interceptor接口,处理sql
2,需要拿到in(ids)里面的集合,就得用到spring aop
解决方案:
1,创建一个自动以注解(参数类型,临时表名称,字段名称)
2,在需要用临时表处理的方法上添加自定义注解
3,aop在方法调用前后拦截,处理自定义注解里面的参数,并且给一个全局类变量赋值(主注意线程安全,可以用ThreadLocal处理)
4,方法里调用mybatis的方法就可以被Interceptor拦截到,根据aop解析的参数,删除临时表,创建临时表,往临时表中批量插入数据
具体思路就上面步骤,接下来就是贴源码了,思路比源码重要,感兴趣的可以根据思路去实现代码
1,自定义注解
import java.lang.annotation.Documented;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
/***
*
* 自定义注解
*
**/
@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface TempTable {
ParamTypeEM paramType();
String tableName() default "tmp_table";
String fieldName() default "relatedIds";
int paramIndex() default 0;
}
2,拦截自定义注解,解析参数
@Aspect
@Component
@Order(10)
public class TempTableInteceptor {
private static final Logger logger = LoggerFactory.getLogger(TempTableInteceptor.class);
/**
* 定义拦截规则
*/
@Pointcut("@annotation(com.xxx.scs.common.inteceptor.annotation.TempTable)")
public void webPointcut(){}
@Before("webPointcut()")
public void before(JoinPoint joinPoint) throws Exception {
//获取方法参数
MethodSignature signature=(MethodSignature) joinPoint.getSignature();
Method method=signature.getMethod();
//拿到自定义注解参数信息
TempTable action=method.getAnnotation(TempTable.class);
ParamTypeEM paramType = action.paramType();
TempTableContext.pType.set(paramType);
logger.debug("paramType:{}",paramType);
Object[] args = joinPoint.getArgs();
if(action.paramIndex()+1>args.length){
logger.error("<<<<<<<<<<<<<<<<<paramIndex设置错误>>>>>>>>>>>>>>");
throw new RuntimeException("paramIndex设置错误!!");
}
Object param = args[action.paramIndex()];
Class<? extends Object> class1 = param.getClass();
String fieldName = action.fieldName();
if(param!=null){
switch (paramType) {
case MULTI:
if(Help.isNull(fieldName)){
throw new RuntimeException("fieldName参数不能为空!!");
}
List<MuiltiIDS> idsList = new ArrayList<MuiltiIDS>();
String[] fieldNames = fieldName.split(",");
for(String field:fieldNames){
Object ids = invokeGetMethod(param, field.replaceAll("\\s*", ""));
MuiltiIDS multiId = new MuiltiIDS();
multiId.setIds(ids);
multiId.setTempTableName(field);//临时表名默认字段名
idsList.add(multiId);
}
TempTableContext.multi.set(idsList);
break;
case ENTITY:
TempTableContext.tableName.set(action.tableName());
if(Help.isNull(fieldName)){
throw new RuntimeException("fieldName参数不能为空!!");
}
String firstStr = fieldName.substring(0, 1).toUpperCase();
String substring = fieldName.substring(1, fieldName.length());
String methodName = "get"+firstStr+substring;
Method method2 = class1.getMethod(methodName);
Object invoke = method2.invoke(param);
TempTableContext.ids.set(invoke);
break;
case COLLECTION:
TempTableContext.tableName.set(action.tableName());
TempTableContext.ids.set(param);
break;
default:
break;
}
}
}
/**
* 拦截器具体实现
* @param pjp
* @return JsonResult(被拦截方法的执行结果,或需要登录的错误提示。)
* @throws Throwable
*/
@Around("webPointcut()") //指定拦截器规则;也可以直接把“execution(* com.xjj.........)”写进这里
public Object Interceptor(ProceedingJoinPoint pjp) throws Throwable{
return pjp.proceed();
}
@After("webPointcut()")
public void after() {
//方法执行后删除掉解析出来的数据
TempTableContext.tableName.remove();
TempTableContext.pType.remove();
TempTableContext.ids.remove();
TempTableContext.insertFlag.remove();
TempTableContext.multi.remove();
}
private Object invokeGetMethod(Object object,String fieldName) throws NoSuchMethodException, SecurityException, IllegalAccessException, IllegalArgumentException, InvocationTargetException{
if(object instanceof Map){
Map map = (Map) object;
return map.get(fieldName);
}
Class<? extends Object> class1 = object.getClass();
String firstStr = fieldName.substring(0, 1).toUpperCase();
String substring = fieldName.substring(1, fieldName.length());
String methodName = ("get"+firstStr+substring);
Method method2 = class1.getMethod(methodName);
Object invoke = method2.invoke(object);
return invoke;
}
}
//保存解析自定义注解信息类(mybatis拦截sql的时候需要获取到这里面的信息,进行临时表的新增)
public class TempTableContext {
public static ThreadLocal<String> tableName = new ThreadLocal<String>();
public static ThreadLocal<Object> ids = new ThreadLocal<Object>();
public static ThreadLocal<ParamTypeEM> pType = new ThreadLocal<ParamTypeEM>();
public static ThreadLocal<Boolean> insertFlag = new ThreadLocal<Boolean>();
public static ThreadLocal<List<MuiltiIDS>> multi = new ThreadLocal<List<MuiltiIDS>>();
}
//参数枚举类
public enum ParamTypeEM {
ENTITY(1,"对象类型"),
COLLECTION(2,"集合类型"),
MULTI(3,"多id集合");
private int type;
private String desc;
private ParamTypeEM(int type,String desc){
this.type = type;
this.desc = desc;
}
//省略get set..... 方法
}
3,方法上添加自定义注解,springaop就会拦截到(由于是spring框架的拦截,所以不能把注解写到xxxxmapper上面去,不然拦截不到)
@TempTable(paramType=ParamTypeEM.MULTI,fieldName="relatedIDs",paramIndex=0)
@Override
public int getsomething(dto dto) {
return cdiMapper.getsomething(dto);
}
@TempTable(paramType=ParamTypeEM.ENTITY,fieldName="relatedIDs",paramIndex=0)
@Override
public int getsomething(dto dto) {
return cdiMapper.getsomething(dto);
}
@TempTable(paramType=ParamTypeEM.COLLECTION,fieldName="relatedIDs",paramIndex=0)
@Override
public int getsomething(dto dto) {
return cdiMapper.getsomething(dto);
}
4,实现mybatis的拦截器Interceptor,在sql执行前创建临时表
@Intercepts( {
@Signature(method = "prepare", type = StatementHandler.class, args = { Connection.class, Integer.class}) })
public class MyBatisInterceptor implements Interceptor {
private static final Logger logger = LoggerFactory.getLogger(MyBatisInterceptor.class);
public Object intercept(Invocation invocation) throws Throwable {
RoutingStatementHandler handler = (RoutingStatementHandler) invocation.getTarget();
//通过反射获取到当前RoutingStatementHandler对象的delegate属性
StatementHandler delegate = (StatementHandler)ReflectUtil.getFieldValue(handler, "delegate");
//获取到当前StatementHandler的 boundSql,这里不管是调用handler.getBoundSql()还是直接调用delegate.getBoundSql()结果是一样的,因为之前已经说过了
//RoutingStatementHandler实现的所有StatementHandler接口方法里面都是调用的delegate对应的方法。
BoundSql boundSql = delegate.getBoundSql();
//获取当前要执行的Sql语句,也就是我们直接在Mapper映射语句中写的Sql语句
String sql = boundSql.getSql();
//利用反射设置当前BoundSql对应的sql属性为我们建立好的分页Sql语句
String invokeSql = sql;
ReflectUtil.setFieldValue(boundSql, "sql", invokeSql );
//通过反射获取delegate父类BaseStatementHandler的mappedStatement属性
MappedStatement mappedStatement = (MappedStatement)ReflectUtil.getFieldValue(delegate, "mappedStatement");
//拦截到的prepare方法参数是一个Connection对象
Connection connection = (Connection)invocation.getArgs()[0];
if(TempTableContext.pType.get()!=null&&((TempTableContext.insertFlag.get()!=null&&!TempTableContext.insertFlag.get()||TempTableContext.insertFlag.get()==null))){
if((TempTableContext.pType.get().equals(ParamTypeEM.ENTITY)||TempTableContext.pType.get().equals(ParamTypeEM.COLLECTION))&&TempTableContext.tableName.get()!=null){
dropTempTable(mappedStatement, connection,TempTableContext.tableName.get());
createTmpTable(mappedStatement,connection,TempTableContext.tableName.get());
insertTempTable(mappedStatement, connection,TempTableContext.tableName.get(),TempTableContext.ids.get());
TempTableContext.insertFlag.set(true);
}else if(TempTableContext.pType.get().equals(ParamTypeEM.MULTI)){
List<MuiltiIDS> list = TempTableContext.multi.get();
if(Help.isNotNull(list)){
for(MuiltiIDS multi:list){
if(Help.isNotNull(multi.getIds())&&Help.isNotNull(multi.getTempTableName())){
dropTempTable(mappedStatement, connection,multi.getTempTableName());
createTmpTable(mappedStatement,connection,multi.getTempTableName());
insertTempTable(mappedStatement, connection,multi.getTempTableName(),multi.getIds());
}
}
TempTableContext.insertFlag.set(true);
}
}
}
return invocation.proceed();
}
public Object plugin(Object target) {
return Plugin.wrap(target, this);
}
public void setProperties(Properties properties) {
}
private void executeSql(
MappedStatement mappedStatement, Connection connection,String sql) {
//通过查询Sql语句获取到对应的计算总记录数的sql语句
//利用Configuration、查询记录数的Sql语句countSql、参数映射关系parameterMappings和参数对象page建立查询记录数对应的BoundSql对象。
BoundSql createTempTableBoundSql = new BoundSql(mappedStatement.getConfiguration(), sql, null, null);
//通过mappedStatement、参数对象page和BoundSql对象countBoundSql建立一个用于设定参数的ParameterHandler对象
ParameterHandler parameterHandler = new DefaultParameterHandler(mappedStatement, null, createTempTableBoundSql);
//通过connection建立一个countSql对应的PreparedStatement对象。
PreparedStatement pstmt = null;
try {
pstmt = connection.prepareStatement(sql);
//通过parameterHandler给PreparedStatement对象设置参数
parameterHandler.setParameters(pstmt);
//之后就是执行获取总记录数的Sql语句和获取结果了。
boolean execute = pstmt.execute(sql);
} catch (SQLException e) {
e.printStackTrace();
} finally {
}
}
//删除临时表
private void dropTempTable(MappedStatement mappedStatement, Connection connection,String tempTableName){
Long start = System.currentTimeMillis();
String sql = " DROP TEMPORARY TABLE IF EXISTS "+tempTableName;
executeSql(mappedStatement, connection, sql);
System.out.println("耗时4:"+(System.currentTimeMillis()-start));
}
//创建临时表
private void createTmpTable(MappedStatement mappedStatement, Connection connection,String tempTableName) {
// TODO Auto-generated method stub
StringBuffer sb = new StringBuffer("");
sb.append("CREATE TEMPORARY TABLE IF NOT EXISTS " + tempTableName + "(");
sb.append("id BIGINT(10) NOT NULL");
sb.append(")ENGINE=HEAP;");
String createTempTableSql = sb.toString();
Long start = System.currentTimeMillis();
executeSql(mappedStatement, connection, createTempTableSql);
logger.info("创建临时表耗时:{}",(System.currentTimeMillis()-start));
}
//往临时表中插入数据
private void insertTempTable(MappedStatement mappedStatement, Connection connection,String tempTableName,Object object){
StringBuffer insertSQL = new StringBuffer();
insertSQL.append("INSERT INTO "+tempTableName+" VALUES ");
if(object!=null){
if(object instanceof Collection){
Collection cids = (Collection) object;
Iterator iterator = cids.iterator();
while(iterator.hasNext()){
insertSQL.append("("+iterator.next()+"),");
}
}else{
Long[] i = (Long[]) object;
for(Long l:i){
insertSQL.append("("+l+"),");
}
}
insertSQL.replace(insertSQL.toString().lastIndexOf(","), insertSQL.length(), "");//去掉最后一个逗号
Long start = System.currentTimeMillis();
executeSql(mappedStatement, connection, insertSQL.toString());
logger.info("临时表插入数据耗时:{}",(System.currentTimeMillis()-start));
}
}
}
//引用到的类
/**
* 利用反射进行操作的一个工具类
*
*/
public class ReflectUtil {
/**
* 利用反射获取指定对象的指定属性
* @param obj 目标对象
* @param fieldName 目标属性
* @return 目标属性的值
*/
public static Object getFieldValue(Object obj, String fieldName) {
Object result = null;
Field field = ReflectUtil.getField(obj, fieldName);
if (field != null) {
field.setAccessible(true);
try {
result = field.get(obj);
} catch (IllegalArgumentException e) {
// TODO Auto-generated catch block
e.printStackTrace();
} catch (IllegalAccessException e) {
// TODO Auto-generated catch block
e.printStackTrace();
}
}
return result;
}
/**
* 利用反射获取指定对象里面的指定属性
* @param obj 目标对象
* @param fieldName 目标属性
* @return 目标字段
*/
private static Field getField(Object obj, String fieldName) {
Field field = null;
for (Class<?> clazz=obj.getClass(); clazz != Object.class; clazz=clazz.getSuperclass()) {
try {
field = clazz.getDeclaredField(fieldName);
break;
} catch (NoSuchFieldException e) {
//这里不用做处理,子类没有该字段可能对应的父类有,都没有就返回null。
}
}
return field;
}
/**
* 利用反射设置指定对象的指定属性为指定的值
* @param obj 目标对象
* @param fieldName 目标属性
* @param fieldValue 目标值
*/
public static void setFieldValue(Object obj, String fieldName,
String fieldValue) {
Field field = ReflectUtil.getField(obj, fieldName);
if (field != null) {
try {
field.setAccessible(true);
field.set(obj, fieldValue);
} catch (IllegalArgumentException e) {
// TODO Auto-generated catch block
e.printStackTrace();
} catch (IllegalAccessException e) {
// TODO Auto-generated catch block
e.printStackTrace();
}
}
}
}