总监:喂,小王啊!起来没呢?加个班呗!
我:泥煤啊…
总监:我有个需求啊,这最近导入数据比较多,但是后台用户反映导入了数据,不想要了,删除起来麻烦啊!你也知道,顾客是上帝嘛,给我完成一个导入数据自动一键回滚的功能!
我:说啥也不干,今天休息,我还要打游戏。
总监:那个你申请一个在家办公,两倍工资,我这面批一个。
我:好嘞!
一:需求分析
- 导入一定分为很多种,有商品的,有图片的,有各种业务的,一定要兼容各种具体的业务,那么就不能依赖于具体实现。
- 分析在各个业务层,导入无谓就是处理完数据之后生成的增删改语句。
- 那我只需要处理sql语句就可以了,把增删改的语句生成它具体的相反的语句。insert生成delete语句,udapte生成delete和insert语句,delete生成insert语句。
- 那么多的mapper层接口的语句,怎么知道哪个语句是需要生成相反的语句呢?可以自定义一个注解,然后我们在执行之前看看该接口上面有没有这个注解就行了。
- 那在并行多次导入的时候怎么区分哪些任务是属于同一任务的呢?这么办,运用线程标识该次任务。那开启多线程怎么办呢?开启多线程就把每一个线程都存入任务名称。
- 好啦,差不多思路就是这些,总结一下就是在sql执行之前,拦截要执行的sql。判定该要执行的sql的mapper层接口是否有自定义约定的注解,如果有,那么该语句是需要生成相反sql的。判定该线程中是否存储有任务名称,有则生成相反语句并存储到redis中,该任务名称为redis中的key。value采用list结构,我们从左向右添加,要是执行的话,也是从左边进行执行。
二:代码编写
-
首先自定义一个注解,EnableReverseSql
@Target({ElementType.METHOD, ElementType.PARAMETER}) @Retention(RetentionPolicy.RUNTIME) public @interface EnableReverseSql { }
-
定义一个生成反向sql的顶级接口,以后用于适配不同的数据库
public interface ReverseSqlDb { String getSql(Statement statement); String insertGenerateDelete(Invocation invocation,String sql); List<String> updateGenerateDeleteAndInsert(Invocation invocation,String sql,String className); String deleteGenerateInsert(Invocation invocation,String sql,String className); String getDbVersion(); }
-
因为要生成反向sql,比如删除,只会知道id,那么我们必须要查询该id的所有信息,才能生成insert语句,这里定义一个mapper接口,用于执行在代码中生成的查询语句。在有@ReverseSqlDb注解的接口层必须继承该类。
public interface ReverseMapper { @Select("${sql}") @InterceptorIgnore(tenantLine = "true") LinkedHashMap<String,Object> performSql(@Param("sql") String sql); }
例:
@Mapper public interface TestMapper extends ReverseMapper { }
这样我们就可以执行代码中任意生成的sql语句。
-
核心思想,要拦截sql,就需要实现Interceptor接口,用于拦截需要执行的语句,新建ReverseSqlInterceptor
@Slf4j @Component @Intercepts({ @Signature(type = StatementHandler.class, method = "update", args = Statement.class), @Signature(type = StatementHandler.class, method = "batch", args = Statement.class) }) public class ReverseSqlInterceptor implements Interceptor { @Resource ReverseSqlDbChainOfResponsibility reverseSqlDbChainOfResponsibility; @Autowired SpringConfigProperties springProperties; /** * 获取当前在使用的数据库 * @return 数据库名称 */ private String getDbVersion(){ String druidUrl; //这里只是为了适配不同数据源进行的判断,最终只是要过去正在使用的是什么数据库 if( springProperties.getDataSource().getDruid() == null){ druidUrl = springProperties.getDataSource().getUrl(); }else{ druidUrl = springProperties.getDataSource().getDruid().getUrl(); } return druidUrl.split(":")[1]; } @Override public Object intercept(Invocation invocation) throws Throwable { //获取Statement类对象 Statement statement = this.getStatement(invocation); Object target = PluginUtils.realTarget(invocation.getTarget()); MetaObject metaObject = SystemMetaObject.forObject(target); MappedStatement mappedStatement = (MappedStatement) metaObject.getValue("delegate.mappedStatement"); //获取命名空间 String namespace = mappedStatement.getId(); //获取类名 String className = namespace.substring(0, namespace.lastIndexOf(".")); //获取当前类的方法名 String methodName = namespace.substring(namespace.lastIndexOf(".") + 1); //获取当前类有哪些方法 Method[] ms = Class.forName(className).getMethods(); for (Method m : ms) { if (m.getName().equals(methodName)) { //判断是否有这个注解 Annotation annotation = m.getAnnotation(EnableReverseSql.class); if (annotation != null) { //通过反射redis类实例对象并获取 Method getMethod = InternalThreadLocal.getMethodForNameGet(); HashMap<String, Object> stringObjectHashMap = (HashMap<String, Object>) getMethod.invoke(InternalThreadLocal.treadUtilEntity); String taskName = (String) stringObjectHashMap.get("name"); if(taskName==null){ throw new RuntimeException("未获取到线程中的任务名称,请添加任务名称"); } reverseSqlDbChainOfResponsibility.selectDbAndExecuteChain(getDbVersion(),invocation,statement,className,taskName); } else { return invocation.proceed(); } } } return invocation.proceed(); } /** * ThreadLocal内部静态类 */ private static class InternalThreadLocal{ private static Class<?> treadUtilClass; private static Constructor<?> declaredConstructor; private static Object treadUtilEntity; private static Class[] treadUtilArguments; /** * 加载ThreadUtil工具类 */ static { try { treadUtilClass = Class.forName("com.common.util.ThreadUtil"); declaredConstructor = treadUtilClass.getDeclaredConstructor(); //强制使用私有的构造方法 declaredConstructor.setAccessible(true); treadUtilEntity = declaredConstructor.newInstance(); treadUtilArguments = new Class[0]; } catch (InstantiationException | IllegalAccessException | ClassNotFoundException | NoSuchMethodException | InvocationTargetException e) { log.error("获取ThreadUtil工具类失败"); e.printStackTrace(); } } private static Method getMethodForNameGet() throws NoSuchMethodException { return treadUtilClass.getMethod("get",treadUtilArguments); } } /** * 获取statement */ private Statement getStatement(Invocation invocation) { Statement statement; Object firstArg = invocation.getArgs()[0]; if (Proxy.isProxyClass(firstArg.getClass())) { statement = (Statement) SystemMetaObject.forObject(firstArg).getValue("h.statement"); } else { statement = (Statement) firstArg; } MetaObject stmtMetaObj = SystemMetaObject.forObject(statement); try { statement = (Statement) stmtMetaObj.getValue("stmt.statement"); } catch (Exception e) { //这个位置不需要捕获异常,会报错 } if (stmtMetaObj.hasGetter("delegate")) { try { statement = (Statement) stmtMetaObj.getValue("delegate"); } catch (Exception e) { //这个位置不需要捕获异常,会报错 } } if(statement != null){ return statement; }else{ throw new RuntimeException("未获取到Statement类"); } } }
-
为了以后适配更多的数据库,新建ReverseSqlDbChainOfResponsibility类,用于适配不同的数据库
@Component public class ReverseSqlDbChainOfResponsibility implements CommandLineRunner, ApplicationContextAware { private Collection<ReverseSqlDb> reverseSqlDbList; private volatile ApplicationContext applicationContext; @Override public void run(String... args) throws Exception { init(); } private void init() { reverseSqlDbList = new LinkedList<>(this.applicationContext.getBeansOfType(ReverseSqlDb.class).values()); } @Override public void setApplicationContext(ApplicationContext applicationContext) throws BeansException { this.applicationContext=applicationContext; } @SneakyThrows void selectDbAndExecuteChain(String dbVersion, Invocation invocation, Statement statement, String className, String taskName){ //反射获取列表的又添加 Method lSetMethod = RedisMethod.getLrSetObj(); //获取列表的右添加列表 Method lSetListMethod = RedisMethod.getLrSetList(); for (ReverseSqlDb reverseSqlDb:reverseSqlDbList ) { if(reverseSqlDb instanceof Proxy){ continue; } if(dbVersion.equals(reverseSqlDb.getDbVersion())){ //判断sql是要执行增删改中的哪一个方法 String sql = reverseSqlDb.getSql(statement); if (sql.contains("INSERT") || sql.contains("insert")) { //调用新增方法生成反向sql String reverseSql = reverseSqlDb.insertGenerateDelete(invocation, sql); lSetMethod.invoke(RedisMethod.redisEntity,taskName,reverseSql); } else if (sql.contains("UPDATE") || sql.contains("update")) { //调用修改方法生成反向sql List<String> reverseSqlList = reverseSqlDb.updateGenerateDeleteAndInsert(invocation, sql, className); lSetListMethod.invoke(RedisMethod.redisEntity,taskName,reverseSqlList); } else if (sql.contains("DELETE") || sql.contains("delete")) { //调用删除方法生成反向sql String reverseSql = reverseSqlDb.deleteGenerateInsert(invocation, sql, className); lSetMethod.invoke(RedisMethod.redisEntity,taskName,reverseSql); } } } } }
-
编写具体的实现类ReverseSqlDbPg
@Slf4j @Component public class ReverseSqlDbPg extends AbstractBusiness implements ReverseSqlDb { private static final String DRUID_POOLED_PREPARED_STATEMENT = "com.alibaba.druid.pool.DruidPooledPreparedStatement"; private static final String T4C_PREPARED_STATEMENT = "oracle.jdbc.driver.T4CPreparedStatement"; private static final String ORACLE_PREPARED_STATEMENT_WRAPPER = "oracle.jdbc.driver.OraclePreparedStatementWrapper"; private Method oracleGetOriginalSqlMethod; private Method druidGetSqlMethod; static final String DB_VERSION = "postgresql"; /** * 获取当前正在执行的sql * * @param statement 声明 * @return 当前要执行的语句 */ @Override public String getSql(Statement statement) { String originalSql = null; String stmtClassName = statement.getClass().getName(); if (DRUID_POOLED_PREPARED_STATEMENT.equals(stmtClassName)) { try { if (druidGetSqlMethod == null) { Class<?> clazz = Class.forName(DRUID_POOLED_PREPARED_STATEMENT); druidGetSqlMethod = clazz.getMethod("getSql"); } Object stmtSql = druidGetSqlMethod.invoke(statement); if (stmtSql instanceof String) { originalSql = (String) stmtSql; } } catch (Exception e) { e.printStackTrace(); } } else if (T4C_PREPARED_STATEMENT.equals(stmtClassName) || ORACLE_PREPARED_STATEMENT_WRAPPER.equals(stmtClassName)) { try { if (oracleGetOriginalSqlMethod != null) { Object stmtSql = oracleGetOriginalSqlMethod.invoke(statement); if (stmtSql instanceof String) { originalSql = (String) stmtSql; } } else { Class<?> clazz = Class.forName(stmtClassName); oracleGetOriginalSqlMethod = getMethodRegular(clazz, "getOriginalSql"); if (oracleGetOriginalSqlMethod != null) { oracleGetOriginalSqlMethod.setAccessible(true); if (null != oracleGetOriginalSqlMethod) { Object stmtSql = oracleGetOriginalSqlMethod.invoke(statement); if (stmtSql instanceof String) { originalSql = (String) stmtSql; } } } } } catch (Exception e) { //ignore } } if (originalSql == null) { originalSql = statement.toString(); } originalSql = originalSql.replaceAll("[\\s]+", StringPool.SPACE); int index = indexOfSqlStart(originalSql); if (index > 0) { originalSql = originalSql.substring(index); } return originalSql; } /** * 新增生成删除 * * @param invocation * @param sql 要执行的sql * @return 生成的反向sql */ @Override public String insertGenerateDelete(Invocation invocation, String sql) { List<String> paramTerList = this.getParamTerList(invocation); //添加的参数列表第一位是id,我们就默认第一位是id,添加的话只需要反向生成删除的sql即可 //获取要删除的表名,添加语句的insert into 表名,所以这里取列表中的第三位 String[] words = sql.split(" "); String tableName = words[2]; //拼接反向sql StringBuilder stringBuilder = new StringBuilder(); stringBuilder.append("delete from "); stringBuilder.append(tableName); stringBuilder.append(" where id = "); stringBuilder.append(paramTerList.get(0)); return stringBuilder.toString(); } /** * 修改方法生成删除和新增方法实现 */ @Override public List<String> updateGenerateDeleteAndInsert(Invocation invocation, String sql, String className) { ArrayList<String> resList = new ArrayList<>(); List<String> paramTerList = this.getParamTerList(invocation); //修改的语句最后的参数为id,默认最后的条件为id.表明则为单词的第二个单词,由此获得id与表名 String[] words = sql.split(" "); String tableName = words[1]; String id = paramTerList.get(paramTerList.size() - 1); //生成删除语句 StringBuffer deleteBuffer = new StringBuffer(); deleteBuffer.append("delete from "); deleteBuffer.append(tableName); deleteBuffer.append(" where id = "); deleteBuffer.append(id); //修改我们需要查询该id的所有数据,这里通过反射注入该接口,并通过继承的方式必须实现我们规定的接口,从而执行拼接的查询sql LinkedHashMap<String, Object> resMap = this.getMapById(className, tableName, id); //获取所有的value Set<String> keys = resMap.keySet(); //拼接新增语句 StringBuffer insertBuffer = new StringBuffer(); insertBuffer.append("insert into "); insertBuffer.append(tableName); insertBuffer.append(" values ("); for (String key : keys ) { if (resMap.get(key) != null) { insertBuffer.append("'"); insertBuffer.append(resMap.get(key)); insertBuffer.append("'"); insertBuffer.append(","); } else { insertBuffer.append(resMap.get(key)); insertBuffer.append(","); } } insertBuffer.deleteCharAt(insertBuffer.length() - 1); insertBuffer.append(")"); //结果添加到列表 resList.add(deleteBuffer.toString()); resList.add(insertBuffer.toString()); log.info(resMap.toString()); return resList; } /** * 删除语句生成新增的具体执行方法 */ @Override public String deleteGenerateInsert(Invocation invocation, String sql, String className) { List<String> paramTerList = this.getParamTerList(invocation); //修改的语句最后的参数为id,默认最后的条件为id.表名则为单词的第三个单词,由此获得id与表名 String[] words = sql.split(" "); String tableName = words[2]; String id = paramTerList.get(paramTerList.size()-1); LinkedHashMap<String, Object> resMap = this.getMapById(className, tableName, id); //获取所有的value Set<String> keys = resMap.keySet(); //拼接新增语句 StringBuffer insertBuffer = new StringBuffer(); insertBuffer.append("insert into "); insertBuffer.append(tableName); insertBuffer.append(" values ("); for (String key:keys ) { if(resMap.get(key)!=null){ insertBuffer.append("'"); insertBuffer.append(resMap.get(key)); insertBuffer.append("'"); insertBuffer.append(","); }else{ insertBuffer.append(resMap.get(key)); insertBuffer.append(","); } } insertBuffer.deleteCharAt(insertBuffer.length()-1); insertBuffer.append(")"); insertBuffer.append(" on CONFLICT(id) do NOTHING "); return insertBuffer.toString(); } /** * 获取当前执行器是哪个执行器 * @return 执行器名称 */ @Override public String getDbVersion() { return DB_VERSION; } /** * 通过反射获取接口并执行继承的方法 */ private LinkedHashMap<String, Object> getMapById(String className, String tableName, String id) { Class<? extends ReverseMapper> serviceClass; try { serviceClass = (Class<? extends ReverseMapper>) Class.forName(className); } catch (Exception e) { throw new RuntimeException("如使用**注解,请继承ReverseMapper接口"); } //生成查询语句 StringBuffer selectBuffer = new StringBuffer(); selectBuffer.append("select * from "); selectBuffer.append(tableName); selectBuffer.append(" where id = "); selectBuffer.append(id); //反射调用规定好的方法 return super.getMapper(serviceClass).performSql(selectBuffer.toString()); } /** * 获取该语句的参数列表 */ private List<String> getParamTerList(Invocation invocation) { Object target = PluginUtils.realTarget(invocation.getTarget()); MetaObject metaObject = SystemMetaObject.forObject(target); // 参数 BoundSql boundSql = (BoundSql) metaObject.getValue("delegate.boundSql"); Object parameterObject = boundSql.getParameterObject(); List<ParameterMapping> parameterMappings = new ArrayList<>(boundSql.getParameterMappings()); if (parameterMappings.isEmpty() && parameterObject == null) { log.warn("parameterMappings is empty or parameterObject is null"); } MappedStatement mappedStatement = (MappedStatement) metaObject.getValue("delegate.mappedStatement"); Configuration configuration = mappedStatement.getConfiguration(); TypeHandlerRegistry typeHandlerRegistry = configuration.getTypeHandlerRegistry(); List<String> parameterList = new ArrayList<>(); MetaObject newMetaObject = configuration.newMetaObject(parameterObject); for (ParameterMapping parameterMapping : parameterMappings) { String parameter = null; if (parameterMapping.getMode() == ParameterMode.OUT) { continue; } String propertyName = parameterMapping.getProperty(); if (typeHandlerRegistry.hasTypeHandler(parameterObject.getClass())) { parameter = getParameterValue(parameterObject); } else if (newMetaObject.hasGetter(propertyName)) { parameter = getParameterValue(newMetaObject.getValue(propertyName)); } else if (boundSql.hasAdditionalParameter(propertyName)) { parameter = getParameterValue(boundSql.getAdditionalParameter(propertyName)); } parameterList.add(parameter); } return parameterList; } /** * 获取参数 * * @param param Object类型参数 * @return 转换之后的参数 */ private static String getParameterValue(Object param) { if (param == null) { return "null"; } if (param instanceof Number) { return param.toString(); } String value = param.toString(); return StringUtils.quotaMark(value); } /** * 获取此方法名的具体 Method * * @param clazz class 对象 * @param methodName 方法名 * @return 方法 */ private Method getMethodRegular(Class<?> clazz, String methodName) { if (Object.class.equals(clazz)) { return null; } for (Method method : clazz.getDeclaredMethods()) { if (method.getName().equals(methodName)) { return method; } } return getMethodRegular(clazz.getSuperclass(), methodName); } /** * 获取sql语句开头部分 * * @param sql ignore * @return ignore */ private int indexOfSqlStart(String sql) { String upperCaseSql = sql.toUpperCase(); Set<Integer> set = new HashSet<>(); set.add(upperCaseSql.indexOf("SELECT ")); set.add(upperCaseSql.indexOf("UPDATE ")); set.add(upperCaseSql.indexOf("INSERT ")); set.add(upperCaseSql.indexOf("DELETE ")); set.remove(-1); if (CollectionUtils.isEmpty(set)) { return -1; } List<Integer> list = new ArrayList<>(set); list.sort(Comparator.naturalOrder()); return list.get(0); } }
-
这里在获取mapper接口中个各个方法的时候,为了防止具体使用者没有继承ReverseMapper接口,写了一个AbstractBusiness类抽象类。代码如下
public abstract class AbstractBusiness { @Resource BeanHelper helper; public void setMapper(BeanHelper helper) { this.helper = helper; } public BeanHelper getMapper() { return helper; } public final <T extends ReverseMapper> T getMapper(Class<T> t){ return helper.getBean(t); } }
-
BeanHelper代码
@Component public class BeanHelper implements ApplicationContextAware { public static ApplicationContext applicationContext; @Override public void setApplicationContext(ApplicationContext applicationContext) throws BeansException { BeanHelper.applicationContext = applicationContext; } public <T extends ReverseMapper> T getBean(Class<T> t) { return applicationContext.getBean(t); } public Object getSpringBean(String s) { return applicationContext.getBean(s); } }
-
-
其他帮助类
@Slf4j public class RedisMethod { private static Class<?> redisClass; public static Object redisEntity; static { try { redisClass = Class.forName("com.component.redis.Redis"); if(redisClass == null){ throw new RuntimeException("获取不到redis工具类,请引入包后重试"); } redisEntity = redisClass.newInstance(); } catch (IllegalAccessException | InstantiationException | ClassNotFoundException e) { log.error("获取Redis工具类失败"); e.printStackTrace(); } } /** * 获取redis工具类右添加obj的方法 * @return 方法 */ @SneakyThrows public static Method getLrSetObj(){ Class[] rightPushArguments = new Class[2]; rightPushArguments[0] = String.class; rightPushArguments[1] = Object.class; return redisClass.getMethod("lrSetObj",rightPushArguments); } /** * 获取又添加列表的方法 * @return 方法 */ @SneakyThrows public static Method getLrSetList(){ Class[] rightPushListArguments = new Class[2]; rightPushListArguments[0] = String.class; rightPushListArguments[1] = List.class; return redisClass.getMethod("lrSetList",rightPushListArguments); } }
工具类在这里,redis工具类
@Slf4j
public class ThreadUtil {
private ThreadUtil(){ }
private static final ThreadLocal threadLocal = new ThreadLocal<>();
public static void set(Object o) {
threadLocal.set(o);
}
public static Object get() {
return threadLocal.get();
}
public static void remove(){
threadLocal.remove();
}
}
三:代码执行
-
public class TestController { @Resource private TestMapper testMapper; public void test(){ //此处放入线程的为任务名称,根据业务自行调整,多线程时也需要放入线程中该任务名称 ThreadUtil.set("test"); testMapper.insert(); ThreadUtil.remove(); } }
-
最后线程执行结束的时候,不要忘了调用ThreadUtil.remove()方法删除,删除线程中数据。
-
思路就是这样,代码还有优化的空间。可自行修改。
四:注意事项
- 当前代码只支持一条条插入,一条条的删除,一条一条的修改,而且只能是最基本的增删改,要使用此功能的话需要编写对应接口的标准的增删改语句,标准语句的格式在代码中有所描述。
- 数据在更新之前都会进行一次查询,速度会响应的减慢。
- 生成的反向sql是存储在redis中的,可以把redis中的key存储在数据库,(反向的sql已经有了,怎么用,怎么执行看自己的业务),redis中的数据可以在下一次导入之前进行删除。
我:喂,总监啊,搞定了啊!
总监:不错,小伙子。