BatchOperationThreadTaskUtil 主要工具类
/**
 * Static helpers for processing a collection of items on multiple threads.
 *
 * <p>Three execution styles are offered:
 * <ul>
 *   <li>{@link #excute} - one new thread per item, each in its own
 *       {@code PROPAGATION_REQUIRES_NEW} transaction; all transactions are committed or
 *       rolled back together once every item has been processed.</li>
 *   <li>{@link #asyncThreadInPool} - fire-and-forget on a bounded per-business pool.</li>
 *   <li>{@link #syncThreadInPool} - bounded per-business pool, caller waits for completion.</li>
 * </ul>
 *
 * <p>The class is a Spring component only so that the transaction manager can be injected
 * and copied into a static field; every operation is exposed as a static method.
 */
@Component
public class BatchOperationThreadTaskUtil {
    private static final Logger log = LoggerFactory.getLogger(BatchOperationThreadTaskUtil.class);

    /**
     * Default upper bound on threads per business code, to avoid exceeding the connection limit.
     */
    private static final int DEFAULT_MAX_POOL_SIZE = 20;

    /**
     * Capacity passed to the task queue of every pool.
     */
    private static final int TASK_QUEUE_CAPACITY = 300;

    /**
     * Thread pools keyed by business code.
     */
    private static final Map<String, ThreadPoolExecutor> threadPoolsMap = new ConcurrentHashMap<>(8);

    /**
     * Transaction manager used by {@link #excute}; populated in {@link #init()}.
     * Volatile so that worker threads always observe the injected instance.
     */
    private static volatile PlatformTransactionManager transactionManager;

    @Autowired
    private PlatformTransactionManager remoteTransactionManager;

    /**
     * Publishes the injected transaction manager to the static field used by the static API.
     */
    @PostConstruct
    public void init() {
        transactionManager = this.remoteTransactionManager;
    }

    /**
     * Initiates an orderly shutdown of every registered pool when the bean is destroyed.
     */
    @PreDestroy
    public void destroy() {
        for (ThreadPoolExecutor pool : threadPoolsMap.values()) {
            pool.shutdown();
        }
    }

    /**
     * Registers a pool with a custom thread limit for a business code.
     * Does nothing if a pool is already registered under that code.
     *
     * @param businessCode business code
     * @param maxPoolSize  maximum number of threads for this business code
     */
    public static void initThreadPool(String businessCode, int maxPoolSize) {
        // Atomic on the map itself, so it cannot race with getThreadPoolExecutor.
        threadPoolsMap.computeIfAbsent(businessCode, code -> newThreadPoolExecutor(maxPoolSize));
    }

    /**
     * Returns the pool registered for a business code, creating one with the default
     * limit if none exists yet.
     *
     * @param businessCode business code
     * @return the pool shared by all callers using this business code
     */
    public static ThreadPoolExecutor getThreadPoolExecutor(String businessCode) {
        // computeIfAbsent creates at most one pool per key, so no pool is built and discarded.
        return threadPoolsMap.computeIfAbsent(businessCode, code -> newThreadPoolExecutor());
    }

    /**
     * Creates a pool with the default thread limit.
     *
     * @return a new pool
     */
    private static ThreadPoolExecutor newThreadPoolExecutor() {
        return newThreadPoolExecutor(DEFAULT_MAX_POOL_SIZE);
    }

    /**
     * Creates a pool with zero core threads, the given maximum, a 1 ms keep-alive and a
     * {@code TaskQueue} of capacity {@value #TASK_QUEUE_CAPACITY} whose parent is the pool.
     *
     * @param maxPoolSize maximum number of threads
     * @return a new pool
     */
    private static ThreadPoolExecutor newThreadPoolExecutor(int maxPoolSize) {
        TaskQueue taskQueue = new TaskQueue(TASK_QUEUE_CAPACITY);
        ThreadPoolExecutor threadPoolExecutor =
                new ThreadPoolExecutor(0, maxPoolSize, 1, TimeUnit.MILLISECONDS, taskQueue);
        taskQueue.setParent(threadPoolExecutor);
        return threadPoolExecutor;
    }

    /**
     * Unregisters the pool of a business code and shuts it down.
     *
     * @param businessCode business code
     */
    private static void removeThreadPool(String businessCode) {
        // remove() returns the previous mapping, so exactly the pool that was unregistered is shut down.
        ThreadPoolExecutor removed = threadPoolsMap.remove(businessCode);
        if (removed != null) {
            removed.shutdown();
        }
    }

    /**
     * Runs {@code callable} for every item on its own thread (no thread limit) and commits or
     * rolls back all of the per-item transactions together.
     * Use with care: prone to deadlocks and to exhausting the connection pool.
     *
     * @param executeBeans items to process
     * @param callable     work to run for each item
     * @param <T>          item type
     * @throws Exception if any item fails or the calling thread is interrupted
     */
    public static <T> void excute(Collection<T> executeBeans, CallableWithException<T> callable) throws Exception {
        excute(executeBeans, callable, (i) -> {});
    }

    /**
     * Runs {@code callable} for every item on its own thread (no thread limit) and commits or
     * rolls back all of the per-item transactions together.
     * Use with care: prone to deadlocks and to exhausting the connection pool.
     *
     * <p>Each worker opens a {@code PROPAGATION_REQUIRES_NEW} transaction, runs the callable and
     * then waits until every worker has reported. If any worker failed, all of them roll back
     * and a {@link RuntimeException} is thrown to the caller. A failure that happens during the
     * final commit of one worker cannot undo commits already made by other workers.
     *
     * @param executeBeans   items to process; {@code null} or empty means nothing to do
     * @param callable       work to run for each item
     * @param middleFunction runs on the calling thread after each worker is started,
     *                       e.g. to pause 100 ms between threads to avoid duplicate ids
     * @param <T>            item type
     * @throws Exception if any item fails, {@code middleFunction} throws, or the calling
     *                   thread is interrupted
     */
    public static <T> void excute(Collection<T> executeBeans, CallableWithException<T> callable, CallableWithException middleFunction) throws Exception {
        if (executeBeans == null || executeBeans.isEmpty()) {
            return;
        }
        CountDownLatch mainLatch = new CountDownLatch(1);
        CountDownLatch sampleLatch = new CountDownLatch(executeBeans.size());
        AtomicBoolean rollBackFlag = new AtomicBoolean(false);
        java.util.concurrent.atomic.AtomicReference<Throwable> firstFailure =
                new java.util.concurrent.atomic.AtomicReference<>();
        try {
            for (T bean : executeBeans) {
                new Thread(() -> runInOwnTransaction(
                        bean, callable, mainLatch, sampleLatch, rollBackFlag, firstFailure)).start();
                middleFunction.call(bean);
            }
            sampleLatch.await();
        } catch (Throwable t) {
            // Starting a thread, middleFunction or the wait itself failed: make every
            // worker that already started roll back instead of waiting forever.
            rollBackFlag.set(true);
            throw t;
        } finally {
            // Always release the workers, otherwise they would block with open transactions.
            mainLatch.countDown();
        }
        if (rollBackFlag.get()) {
            throw new RuntimeException("批量操作失败", firstFailure.get());
        }
    }

    /**
     * Worker body for {@link #excute}: runs one item in its own transaction, reports to
     * {@code sampleLatch} exactly once, waits for the caller's decision and then commits
     * or rolls back.
     */
    private static <T> void runInOwnTransaction(T bean, CallableWithException<T> callable,
            CountDownLatch mainLatch, CountDownLatch sampleLatch, AtomicBoolean rollBackFlag,
            java.util.concurrent.atomic.AtomicReference<Throwable> firstFailure) {
        TransactionStatus status = null;
        boolean workSucceeded = false;
        try {
            // Skip the work entirely if another item has already failed.
            if (!rollBackFlag.get()) {
                DefaultTransactionDefinition def = new DefaultTransactionDefinition();
                def.setPropagationBehavior(TransactionDefinition.PROPAGATION_REQUIRES_NEW);
                status = transactionManager.getTransaction(def);
                callable.call(bean);
                workSucceeded = true;
            }
        } catch (Throwable t) {
            log.error("批量操作失败", t);
            firstFailure.compareAndSet(null, t);
            rollBackFlag.set(true);
        } finally {
            // Report exactly once on every path so the caller can never wait forever.
            sampleLatch.countDown();
        }
        if (status == null) {
            // Skipped, or the transaction could not be opened: nothing to complete.
            return;
        }
        if (workSucceeded) {
            try {
                mainLatch.await();
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                log.error("批量操作失败", e);
                firstFailure.compareAndSet(null, e);
                rollBackFlag.set(true);
                workSucceeded = false;
            }
        }
        try {
            if (workSucceeded && !rollBackFlag.get()) {
                transactionManager.commit(status);
            } else {
                transactionManager.rollback(status);
            }
        } catch (RuntimeException e) {
            // The transaction is already completed here; rolling back again would only fail.
            log.error("批量操作失败", e);
            rollBackFlag.set(true);
        }
    }

    /**
     * Submits {@code callable} for every item to the bounded pool of the business code and
     * returns without waiting for the tasks to finish.
     *
     * @param businessCode business code; callers using the same code share one thread limit
     * @param executeBeans items to process
     * @param callable     work to run for each item
     * @param <T>          item type
     * @throws Exception if submission is rejected by the pool
     */
    public static <T> void asyncThreadInPool(String businessCode, Collection<T> executeBeans, CallableWithException<T> callable) throws Exception {
        asyncThreadInPool(businessCode, executeBeans, callable, (i) -> {});
    }

    /**
     * Submits {@code callable} for every item to the bounded pool of the business code and
     * returns without waiting for the tasks to finish. Failures of individual items are logged.
     *
     * @param businessCode   business code; callers using the same code share one thread limit
     * @param executeBeans   items to process; {@code null} or empty means nothing to do
     * @param callable       work to run for each item
     * @param middleFunction runs on the calling thread after each submission, e.g. to pause
     *                       100 ms between tasks; may have no effect once the thread limit is reached
     * @param <T>            item type
     * @throws Exception if submission is rejected by the pool or {@code middleFunction} throws
     */
    public static <T> void asyncThreadInPool(String businessCode, Collection<T> executeBeans, CallableWithException<T> callable, CallableWithException middleFunction) throws Exception {
        if (executeBeans == null || executeBeans.isEmpty()) {
            return;
        }
        ThreadPoolExecutor threadPoolExecutor = getThreadPoolExecutor(businessCode);
        for (T bean : executeBeans) {
            threadPoolExecutor.execute(() -> callAndLogFailure(businessCode, bean, callable));
            middleFunction.call(bean);
        }
    }

    /**
     * Submits {@code callable} for every item to the bounded pool of the business code and
     * waits until all of them have finished.
     *
     * @param businessCode business code; callers using the same code share one thread limit
     * @param executeBeans items to process
     * @param callable     work to run for each item
     * @param <T>          item type
     * @throws Exception if submission is rejected or the calling thread is interrupted
     */
    public static <T> void syncThreadInPool(String businessCode, Collection<T> executeBeans, CallableWithException<T> callable) throws Exception {
        syncThreadInPool(businessCode, executeBeans, callable, (i) -> {});
    }

    /**
     * Submits {@code callable} for every item to the bounded pool of the business code and
     * waits until all of them have finished. Failures of individual items are logged and do
     * not stop the remaining items.
     *
     * @param businessCode   business code; callers using the same code share one thread limit
     * @param executeBeans   items to process; {@code null} or empty means nothing to do
     * @param callable       work to run for each item
     * @param middleFunction runs on the calling thread after each submission, e.g. to pause
     *                       100 ms between tasks; may have no effect once the thread limit is reached
     * @param <T>            item type
     * @throws Exception if submission is rejected, {@code middleFunction} throws, or the
     *                   calling thread is interrupted while waiting
     */
    public static <T> void syncThreadInPool(String businessCode, Collection<T> executeBeans, CallableWithException<T> callable, CallableWithException middleFunction) throws Exception {
        if (executeBeans == null || executeBeans.isEmpty()) {
            return;
        }
        CountDownLatch threadLatch = new CountDownLatch(executeBeans.size());
        ThreadPoolExecutor threadPoolExecutor = getThreadPoolExecutor(businessCode);
        for (T bean : executeBeans) {
            threadPoolExecutor.execute(() -> {
                try {
                    callAndLogFailure(businessCode, bean, callable);
                } finally {
                    threadLatch.countDown();
                }
            });
            middleFunction.call(bean);
        }
        threadLatch.await();
    }

    /**
     * Runs the callable for one item and logs (instead of propagating) any exception.
     * String concatenation is used for the item so that a {@code null} item cannot cause
     * a second failure while the first one is being logged.
     */
    private static <T> void callAndLogFailure(String businessCode, T bean, CallableWithException<T> callable) {
        try {
            callable.call(bean);
        } catch (Exception e) {
            log.error("批量操作线程异常!批处理业务代码:【" + businessCode + "】;批处理对象:【" + bean + "】", e);
        }
    }
}
自定义函数式接口CallableWithException
/**
 * A single-argument action that returns nothing and is allowed to throw a checked exception.
 *
 * @param <T> type of the item handed to the action
 */
@FunctionalInterface
public interface CallableWithException<T> {

    /**
     * Performs the action on the given item.
     *
     * @param item the item to process
     * @throws Exception if the action fails
     */
    void call(T item) throws Exception;
}
使用
1、可在业务类初始化业务线程池上限
static {
BatchOperationThreadTaskUtil.initThreadPool("entrustBatchWorkFlow",50);
}
2、使用例子
// Process every sample id on its own thread; all per-item transactions commit or roll back together.
BatchOperationThreadTaskUtil.excute(sampleIds, sampleId -> {
    // business code goes here
});
3、只用线程池
// Redis keys: the lock / pool name, the global set of report codes being processed,
// and the per-user set of report codes belonging to this user's running batch.
String reportBatchWorkFlow = "reportBatchWorkFlow";
String pendingReportCode = "pendingReportCode";
String reportBatchSet = reportBatchWorkFlow + "_set-" + userId;
// Validate and register the batch while holding the distributed lock.
if (RedisLockUtils.tryLock(reportBatchWorkFlow)) {
    try {
        // Reject if this user still has a batch in progress.
        if (RedisUtils.hasKey(reportBatchSet) && CollectionUtils.isNotEmpty(RedisUtils.getSets(reportBatchSet))) {
            return new ResponseData(1, "当前有任务正在处理中");
        }
        // Reject if any requested report is already being processed by someone else.
        Set<String> existReportCodes = RedisUtils.setIntersect(pendingReportCode, reportCodes);
        if (CollectionUtils.isNotEmpty(existReportCodes)) {
            return new ResponseData(3, "报告[" + String.join(",", existReportCodes) + "]正在处理中,请稍后再试");
        }
        // NOTE(review): setsAdd takes Object... - this only adds the individual codes if
        // reportCodes is an array; a Collection would be passed as one single element. Confirm.
        RedisUtils.setsAdd(reportBatchSet, reportCodes);
        RedisUtils.setsAdd(pendingReportCode, reportCodes);
    } finally {
        RedisLockUtils.releaseLock(reportBatchWorkFlow);
    }
} else {
    return new ResponseData(4, "获取锁超时,请稍后再试");
}
// Fetch the bounded pool shared by this business code.
ThreadPoolExecutor threadPoolExecutor = BatchOperationThreadTaskUtil.getThreadPoolExecutor(reportBatchWorkFlow);
for (String reportCode : reportCodes) {
    try {
        threadPoolExecutor.execute(() -> {
            try {
                // business code goes here
            } catch (Exception e) {
                logger.error("批处理报告出错", e);
            } finally {
                // Always clear the "in progress" markers for this report.
                RedisUtils.setsRemove(reportBatchSet, reportCode);
                RedisUtils.setsRemove(pendingReportCode, reportCode);
            }
        });
    } catch (java.util.concurrent.RejectedExecutionException e) {
        // The task never ran, so its finally block will not clear the markers; do it here,
        // otherwise the report would stay "in progress" forever.
        logger.error("批处理报告出错", e);
        RedisUtils.setsRemove(reportBatchSet, reportCode);
        RedisUtils.setsRemove(pendingReportCode, reportCode);
    }
}
注意:企业应用还是慎用多线程。
遇到问题:
1、异步多线程中,业务代码中用到Request,获取url地址信息。当主线程中request返回时,已经无法使用原来的request,导致系统报错。
2、业务代码中用到获取当前登录用户信息时,有以下几个问题:(1)线程池中的线程会被复用,线程里保存的是第一个开启该线程的用户,导致后面复用该线程时,获取到的当前登录用户不正确。(2)若线程存活时间超过登录超时时间,则超时之后获取不到当前用户。(3)若开启线程之后该用户退出,后面用到该线程会提示会话已经失效。(4)若采用“一条失败,全部回滚”的方式,则需要小心:批处理数量一定要小于线程上限,否则已执行的线程会一直等待还在排队的任务,造成永久等待,所以上面的工具类并没有提供“有上限同步回滚”方法。(5)若采用无上限多线程,容易超过数据库最大连接数。
附件:RedisLockUtils
/**
 * Minimal Redis-backed lock built on "set if absent".
 *
 * <p>NOTE(review): the lock has no expiry and {@link #releaseLock(String)} deletes the key
 * without checking who holds it, so a crashed holder blocks others until the next startup
 * and any caller can release a lock it does not own. Confirm this is acceptable.
 */
@Component
public class RedisLockUtils {
    private static RedisTemplate<String, String> redisTemplate;
    public static final String LOCKPREFIX = "LOCK_";
    public static final int DEFAULT_RETRY_COUNT = 10;
    public static final long DEFAULT_RETRY_TIME = 100;

    @Autowired
    private JedisConnectionFactory jedisConnectionFactory;

    /**
     * Builds a string-serialized template on the injected connection factory, deletes every
     * existing key that starts with {@link #LOCKPREFIX}, and publishes the template to the
     * static field used by the static API.
     */
    @PostConstruct
    public void init() {
        RedisTemplate<String, String> template = new RedisTemplate<>();
        template.setConnectionFactory(jedisConnectionFactory);
        template.setDefaultSerializer(new StringRedisSerializer());
        template.afterPropertiesSet();
        // NOTE(review): this removes ALL lock keys, including ones held by other
        // application instances sharing the same Redis - confirm single-instance deployment.
        template.delete(template.keys(LOCKPREFIX + "*"));
        redisTemplate = template;
    }

    /**
     * Tries to acquire the lock using the default retry count and retry interval.
     *
     * @param key business code identifying the lock
     * @return {@code true} if the lock was acquired, {@code false} otherwise
     * @throws InterruptedException if interrupted while waiting between attempts
     */
    public static Boolean tryLock(String key) throws InterruptedException {
        return tryLock(key, DEFAULT_RETRY_COUNT, DEFAULT_RETRY_TIME);
    }

    /**
     * Tries to acquire the lock, making one initial attempt plus up to {@code count} retries.
     * The stored value is the name of the calling thread.
     *
     * @param key       business code identifying the lock
     * @param count     number of retries after the first attempt; a negative value means
     *                  no attempt is made at all
     * @param retryTime pause between attempts, in milliseconds
     * @return {@code true} if the lock was acquired, {@code false} otherwise
     * @throws InterruptedException if interrupted while waiting between attempts
     */
    public static Boolean tryLock(String key, int count, long retryTime) throws InterruptedException {
        String lockKey = LOCKPREFIX + key;
        String holder = Thread.currentThread().getName();
        for (int remaining = count; remaining >= 0; remaining--) {
            // Boolean.TRUE.equals guards against a null reply instead of unboxing it.
            if (Boolean.TRUE.equals(redisTemplate.opsForValue().setIfAbsent(lockKey, holder))) {
                return Boolean.TRUE;
            }
            // Only pause when another attempt will follow; never after success or the last failure.
            if (remaining > 0) {
                Thread.sleep(retryTime);
            }
        }
        return Boolean.FALSE;
    }

    /**
     * Releases the lock by deleting its key.
     *
     * @param key business code identifying the lock
     */
    public static void releaseLock(String key) {
        redisTemplate.delete(LOCKPREFIX + key);
    }
}
RedisUtils
/**
 * Static convenience wrappers around a string-serialized {@code RedisTemplate}.
 *
 * <p>The template is kept as a raw type because the public methods accept and return
 * untyped values ({@code Object}, raw {@code Set}); changing that would alter the API.
 */
@Component
public class RedisUtils {
    private static RedisTemplate redisTemplate;

    @Autowired
    private JedisConnectionFactory jedisConnectionFactory;

    /**
     * Builds a string-serialized template on the injected connection factory and publishes
     * it to the static field used by the static API.
     */
    @PostConstruct
    public void init() {
        RedisTemplate template = new RedisTemplate<>();
        template.setConnectionFactory(jedisConnectionFactory);
        template.setDefaultSerializer(new StringRedisSerializer());
        template.afterPropertiesSet();
        redisTemplate = template;
    }

    /**
     * @return the shared template
     */
    public static RedisTemplate getRedisTemplate() {
        return redisTemplate;
    }

    /**
     * Deletes a key.
     *
     * @param key key to delete
     */
    public static void delKey(Object key) {
        redisTemplate.delete(key);
    }

    /**
     * Sets the time-to-live of a key.
     *
     * @param key     key to expire
     * @param timeout time-to-live
     * @param unit    unit of {@code timeout}
     * @return the result reported by the template
     */
    public static Boolean expire(String key, final long timeout, final TimeUnit unit) {
        return redisTemplate.expire(key, timeout, unit);
    }

    /**
     * Stores a string value.
     *
     * @param key   key
     * @param value value
     */
    public static void setKey(String key, String value) {
        redisTemplate.opsForValue().set(key, value);
    }

    /**
     * Stores a string value with a time-to-live.
     *
     * @param key      key
     * @param value    value
     * @param time     time-to-live
     * @param timeUnit unit of {@code time}
     */
    public static void setKey(String key, String value, long time, TimeUnit timeUnit) {
        redisTemplate.opsForValue().set(key, value, time, timeUnit);
    }

    /**
     * Reads the value of a key.
     *
     * @param key key
     * @return the stored value as returned by the template
     */
    public static Object getKey(String key) {
        return redisTemplate.opsForValue().get(key);
    }

    /**
     * Stores the value only if the key does not exist yet.
     *
     * @param key   key
     * @param value value
     * @return the result of the template's set-if-absent call
     */
    public static Boolean setNx(String key, String value) {
        return redisTemplate.opsForValue().setIfAbsent(key, value);
    }

    /**
     * Adds {@code delta} to the numeric value of a key.
     *
     * @param key   key
     * @param delta amount to add
     * @return the value reported by the template after the increment
     */
    public static Long increment(String key, long delta) {
        return redisTemplate.opsForValue().increment(key, delta);
    }

    /**
     * Adds members to a set. Does nothing when no members are given, because an add
     * without members is not a valid Redis command.
     *
     * @param key   key of the set
     * @param value members to add
     */
    public static void setsAdd(String key, Object... value) {
        if (value == null || value.length == 0) {
            return;
        }
        redisTemplate.opsForSet().add(key, value);
    }

    /**
     * Removes a member from a set.
     *
     * @param key   key of the set
     * @param value member to remove
     */
    public static void setsRemove(String key, String value) {
        redisTemplate.opsForSet().remove(key, value);
    }

    /**
     * Returns the number of members in a set.
     *
     * @param key key of the set
     * @return the size, or 0 if the template reports no value
     */
    public static long setsCount(String key) {
        Long size = redisTemplate.opsForSet().size(key);
        // Avoid unboxing a null reply.
        return size == null ? 0L : size;
    }

    /**
     * Returns the members of a set.
     *
     * @param key key of the set
     * @return the members as returned by the template
     */
    public static Set getSets(String key) {
        return redisTemplate.opsForSet().members(key);
    }

    /**
     * Intersects two sets stored in Redis.
     *
     * @param key      key of the first set
     * @param otherKey key of the second set
     * @return the intersection as returned by the template
     */
    public static Set setIntersect(String key, String otherKey) {
        return redisTemplate.opsForSet().intersect(key, otherKey);
    }

    /**
     * Intersects a set stored in Redis with an in-memory collection.
     *
     * @param key  key of the set in Redis
     * @param list values to compare against; {@code null} is treated as empty
     * @return the intersection; empty when {@code list} is {@code null} or empty
     */
    public static Set setIntersect(String key, Collection list) {
        if (list == null) {
            return new java.util.HashSet<>();
        }
        return setIntersect(key, list.toArray());
    }

    /**
     * Intersects a set stored in Redis with an in-memory array by writing the array to a
     * temporary set, intersecting, and deleting the temporary set again.
     *
     * @param key    key of the set in Redis
     * @param arrays values to compare against; {@code null} is treated as empty
     * @return the intersection; empty when {@code arrays} is {@code null} or empty
     */
    public static Set setIntersect(String key, Object[] arrays) {
        if (arrays == null || arrays.length == 0) {
            // Nothing can intersect with an empty input; skip the Redis round trips.
            return new java.util.HashSet<>();
        }
        String tempKey = "TMP_CMP_SET-" + UUID.randomUUID().toString();
        try {
            setsAdd(tempKey, arrays);
            return setIntersect(key, tempKey);
        } finally {
            // Remove the temporary set even when the add or intersect fails.
            delKey(tempKey);
        }
    }

    /**
     * Checks whether a key exists.
     *
     * @param key key
     * @return the result reported by the template
     */
    public static Boolean hasKey(String key) {
        return redisTemplate.hasKey(key);
    }
}