自用线程池工具类

BatchOperationThreadTaskUtil 主要工具类


@Component
public class BatchOperationThreadTaskUtil {
    private static final Logger log = LoggerFactory.getLogger(BatchOperationThreadTaskUtil.class);

    /**
     * 同业务最大线程上限,防止超过连接数
     */
    private static final int DEFAULTMAXPOOLSIZE = 20;

    /**
     * 业务线程池
     */
    private static final Map<String, ThreadPoolExecutor> threadPoolsMap = new ConcurrentHashMap<>(8);

    private static PlatformTransactionManager transactionManager;
    @Autowired
    private PlatformTransactionManager remoteTransactionManager;
    @PostConstruct
     public void init() {
        this.transactionManager = this.remoteTransactionManager;
     }
    @PreDestroy
     public void destroy() {
        for (Map.Entry<String, ThreadPoolExecutor> entry:
             threadPoolsMap.entrySet()) {
            entry.getValue().shutdown();
        }
     }

    /**
     * 初始化自定义业务上限
     * @param businessCode      业务编码
     * @param maxPoolSize       上限
     */
     public static synchronized void initThreadPool(String businessCode,int maxPoolSize){
       if(threadPoolsMap.get(businessCode) == null){
           ThreadPoolExecutor threadPoolExecutor = newThreadPoolExecutor(maxPoolSize);
           threadPoolsMap.put(businessCode,threadPoolExecutor);
       }
     }

    /**
     * 根据业务代码获取线程池
     * @param businessCode  业务代码
     * @return
     */
    public static ThreadPoolExecutor getThreadPoolExecutor(String businessCode){
        ThreadPoolExecutor threadPoolExecutor = threadPoolsMap.get(businessCode);
        if(threadPoolExecutor == null){
            threadPoolExecutor = newThreadPoolExecutor();
            synchronized (threadPoolsMap){
                if(threadPoolsMap.get(businessCode) == null){
                    threadPoolsMap.put(businessCode,threadPoolExecutor);
                }else{
                    threadPoolExecutor = threadPoolsMap.get(businessCode);
                }
            }

        }
        return  threadPoolExecutor;
    }

    /**
     * 新建线程池
     * @return
     */
    private static ThreadPoolExecutor newThreadPoolExecutor(){
        return newThreadPoolExecutor(DEFAULTMAXPOOLSIZE);
    }

    /**
     * 新建线程池
     * @param maxPoolSize
     * @return
     */
    private static ThreadPoolExecutor newThreadPoolExecutor(int maxPoolSize){
        TaskQueue taskQueue = new TaskQueue(300);
        ThreadPoolExecutor threadPoolExecutor = new ThreadPoolExecutor(0,maxPoolSize,1, TimeUnit.MILLISECONDS, taskQueue);
        taskQueue.setParent(threadPoolExecutor);
        return threadPoolExecutor;
    }

    /**
     * 移除业务线程池
     * @param businessCode      业务代码
     */
    private static void removeThreadPool(String businessCode){
        ThreadPoolExecutor threadPoolExecutor = threadPoolsMap.get(businessCode);
        if (threadPoolExecutor != null){
            threadPoolExecutor.shutdown();
            threadPoolExecutor = null;
        }
        threadPoolsMap.remove(businessCode);
    }

    /**
     * 无上限同步多线程执行,同事务,可一起回滚
     * 慎用,易死锁,易超过连接池上限
     * @param executeBeans
     * @param callable
     */
    public static <T> void excute(Collection<T> executeBeans, CallableWithException<T> callable) throws Exception{
       excute(executeBeans,callable,(i)->{});
    }

    /**
     * 无上限同步多线程执行,同事务,可一起回滚
     * 慎用,易死锁,易超过连接池上限
     * @param executeBeans      遍历对象
     * @param callable          多线程执行方法体
     * @param middleFunction    线程之间执行方法体,例如相隔100ms执行,防止id重复
     * @param <T>
     * @throws Exception
     */
    public static <T> void excute(Collection<T> executeBeans, CallableWithException<T> callable, CallableWithException middleFunction) throws Exception{
        CountDownLatch mainLatch = new CountDownLatch(1);
        CountDownLatch sampleLatch = new CountDownLatch(executeBeans.size());
        AtomicBoolean rollBackFlag = new AtomicBoolean(false);
        for (T b: executeBeans) {
            new Thread(new Runnable() {
                @Override
                public void run() {
                    if(rollBackFlag.get()){
                        sampleLatch.countDown();
                        return;
                    }

                    DefaultTransactionDefinition def = new DefaultTransactionDefinition();
                    def.setPropagationBehavior(TransactionDefinition.PROPAGATION_REQUIRES_NEW);
                    TransactionStatus status = transactionManager.getTransaction(def);

                    try{
                       
                        callable.call(b);

                        sampleLatch.countDown();
                        mainLatch.await();
                        if(rollBackFlag.get()){
                            transactionManager.rollback(status);
                        }else{
                            transactionManager.commit(status);
                        }
                    }catch (Exception e){
                        log.error("批量操作失败",e);
                        rollBackFlag.set(true);
                        sampleLatch.countDown();
                        transactionManager.rollback(status);
                    }
                }
            }).start();
            middleFunction.call(b);
        }


        sampleLatch.await();
        mainLatch.countDown();
        if(rollBackFlag.get()){
            throw new RuntimeException("批量操作失败");
        }
    }

    /**
     * 有上限多线程异步运行
     * @param businessCode  同个业务公用上限
     * @param executeBeans  遍历对象
     * @param callable      线程执行方法体
     * @param <T>
     * @throws Exception
     */
    public static <T> void asyncThreadInPool(String businessCode, Collection<T> executeBeans, CallableWithException<T> callable) throws Exception{
        asyncThreadInPool(businessCode,executeBeans,callable,(i)->{});
    }

    /**
     * 有上限多线程异步运行
     * @param businessCode  同个业务公用上限
     * @param executeBeans  遍历对象
     * @param callable      线程执行方法体
     * @param middleFunction    多线程间方法体。例如线程之间间隔100ms。当超过上线后的线程间可能失效
     * @param <T>
     * @throws Exception
     */
    public static <T> void asyncThreadInPool(String businessCode, Collection<T> executeBeans, CallableWithException<T> callable, CallableWithException middleFunction) throws Exception{
        ThreadPoolExecutor threadPoolExecutor = getThreadPoolExecutor(businessCode);
        for (T b: executeBeans) {
            threadPoolExecutor.execute(new Runnable() {
                @Override
                public void run() {
                    try {
                        callable.call(b);
                    } catch (Exception e) {
                        log.info("批量操作线程异常!批处理业务代码:【"+businessCode+"】;批处理对象:【"+b.toString()+"】",e);
                    }
                }
            });
            middleFunction.call(b);
        }
    }

    /**
     * 有上限多线程同步运行
     * @param businessCode  同个业务公用上限
     * @param executeBeans  遍历对象
     * @param callable      线程执行方法体
     * @param <T>
     * @throws Exception
     */
    public static <T> void syncThreadInPool(String businessCode, Collection<T> executeBeans, CallableWithException<T> callable) throws Exception{
        syncThreadInPool(businessCode,executeBeans,callable,(i)->{});
    }

    /**
     * 有上限多线程同步运行
     * @param businessCode  同个业务公用上限
     * @param executeBeans  遍历对象
     * @param callable      线程执行方法体
     * @param middleFunction    多线程间方法体。例如线程之间间隔100ms。当超过上线后的线程间可能失效
     * @param <T>
     * @throws Exception
     */
    public static <T> void syncThreadInPool(String businessCode, Collection<T> executeBeans, CallableWithException<T> callable, CallableWithException middleFunction) throws Exception{
        CountDownLatch threadLatch = new CountDownLatch(executeBeans.size());
        ThreadPoolExecutor threadPoolExecutor = getThreadPoolExecutor(businessCode);
        for (T b: executeBeans) {
            threadPoolExecutor.execute(new Runnable() {
                @Override
                public void run() {
                    try {
                        callable.call(b);
                    } catch (Exception e) {
                        log.info("批量操作线程异常!批处理业务代码:【"+businessCode+"】;批处理对象:【"+b.toString()+"】",e);
                    }finally {
                        threadLatch.countDown();
                    }
                }
            });
            middleFunction.call(b);
        }

        threadLatch.await();
    }


}

自定义函数式接口CallableWithException

@FunctionalInterface
public interface CallableWithException<T>{
    void call(T e) throws Exception;
}

使用
1、可在业务类初始化业务线程池上限


    static {
        BatchOperationThreadTaskUtil.initThreadPool("entrustBatchWorkFlow",50);
    }

2、使用例子

 BatchOperationThreadTaskUtil.excute(sampleIds,(sampleId) ->{
           //业务代码
 });

3、只用线程池

 		String reportBatchWorkFlow = "reportBatchWorkFlow";
        String pendingReportCode = "pendingReportCode";
        String reportBatchSet = reportBatchWorkFlow + "_set-" + userId;

        //校验
        if(RedisLockUtils.tryLock(reportBatchWorkFlow)){
            try{
                //当前任务校验
                if(RedisUtils.hasKey(reportBatchSet) && CollectionUtils.isNotEmpty(RedisUtils.getSets(reportBatchSet))){
                    return new ResponseData(1,"当前有任务正在处理中");
                }

                //处理中任务校验
                Set existReportCodes = RedisUtils.setIntersect(pendingReportCode,reportCodes);
                if(CollectionUtils.isNotEmpty(existReportCodes)){
                    return new ResponseData(3,"报告["+String.join(",",existReportCodes)+"]正在处理中,请稍后再试");
                }

                RedisUtils.setsAdd(reportBatchSet,reportCodes);
                RedisUtils.setsAdd(pendingReportCode,reportCodes);

            }finally {
                RedisLockUtils.releaseLock(reportBatchWorkFlow);
            }
        }else{
            return new ResponseData(4,"获取锁超时,请稍后再试");
        }

        //获取线程池
        ThreadPoolExecutor threadPoolExecutor = BatchOperationThreadTaskUtil.getThreadPoolExecutor("reportBatchWorkFlow");=
        for (String reportCode: reportCodes) {
            threadPoolExecutor.execute(new Runnable() {
                @Override
                public void run() {
                   try{
                       //业务代码
                   }catch (Exception e){
                        logger.error("批处理报告出错",e);
                   }finally {
                       RedisUtils.setsRemove(reportBatchSet,reportCode);
                       RedisUtils.setsRemove(pendingReportCode,reportCode);
                   }
                }
            });
        }

注意:企业应用还是慎用多线程。
遇到问题:
1、异步多线程中,业务代码中用到Request,获取url地址信息。当主线程中request返回时,已经无法使用原来的request,导致系统报错。
2、业务代码中用到获取当前登录用户信息。故有几个问题(1)线程一直存在,故获取当前用户是第一个开启线程的用户,导致后面获取当前线程时,获取当前登录用户不正确。(2)若线程存活时间超过登录超时时间,则超时之后获取不到当前用户。(3)若开启线程之后该用户退出,后面用到该线程会提示会话已经失效。(4)若采用一条失败,全部回滚,则需要小心,批处理数量一定要小于上限,否则会一直等待,所以上面工具类并没有“有上限同步回滚”方法。(5)若采用无上限多线程,容易超过数据库最大连接数。

附件:RedisLockUtils


@Component
public class RedisLockUtils {

    private static RedisTemplate redisTemplate;

    public static final String LOCKPREFIX = "LOCK_";
    public static final int DEFAULT_RETRY_COUNT = 10;
    public static final long DEFAULT_RETRY_TIME = 100;

    @Autowired
    private JedisConnectionFactory jedisConnectionFactory;

    @PostConstruct
    public void init() {
        this.redisTemplate = new RedisTemplate<>();
        redisTemplate.setConnectionFactory(jedisConnectionFactory);
        StringRedisSerializer stringRedisSerializer = new StringRedisSerializer();
        redisTemplate.setDefaultSerializer(stringRedisSerializer);
        redisTemplate.afterPropertiesSet();

        redisTemplate.delete(redisTemplate.keys(LOCKPREFIX+"*"));
    }


    /**
     * 获取锁
     * @param key   业务代码
     * @return
     * @throws InterruptedException
     */
    public static Boolean tryLock(String key) throws InterruptedException {
       return tryLock(key,DEFAULT_RETRY_COUNT,DEFAULT_RETRY_TIME);
    }
    /**
     * 获取锁
     * @param key   业务代码
     * @param count 重试次数
     * @return
     * @throws InterruptedException
     */
    public static Boolean tryLock(String key,int count,long retryTime) throws InterruptedException {
        Boolean lock = Boolean.FALSE;
        do{
            if(count < 0){
                return false;
            }
            lock = redisTemplate.opsForValue().setIfAbsent(LOCKPREFIX+key, Thread.currentThread().getName());
            Thread.sleep(retryTime);
            count--;
        }while (!lock);

        return lock;
    }


    public static void releaseLock(String key){
        redisTemplate.delete(LOCKPREFIX+key);
    }

}

RedisUtils


@Component
public class RedisUtils {

    private static RedisTemplate redisTemplate;

    @Autowired
    private JedisConnectionFactory jedisConnectionFactory;

    @PostConstruct
    public void init() {
        this.redisTemplate = new RedisTemplate<>();
        redisTemplate.setConnectionFactory(jedisConnectionFactory);
        StringRedisSerializer stringRedisSerializer = new StringRedisSerializer();
        redisTemplate.setDefaultSerializer(stringRedisSerializer);
        redisTemplate.afterPropertiesSet();
    }

    public static RedisTemplate getRedisTemplate(){
        return redisTemplate;
    }

    /**
     * 删除key
     * @param key
     */
    public static void delKey(Object key){
        redisTemplate.delete(key);
    }
    /**
     * 删除key
     * @param key
     */
    public static Boolean expire(String key, final long timeout, final TimeUnit unit) {
        return redisTemplate.expire(key,timeout,unit);
    }

    /**
     * 设置字符串key value
     * @param key
     * @param value
     */
    public static void setKey(String key, String value){
        redisTemplate.opsForValue().set(key, value);
    }

    /**
     * 设置过期时间的key value
     * @param key
     * @param value
     * @param time
     * @param timeUnit
     */
    public static void setKey(String key, String value, long time, TimeUnit timeUnit){
        redisTemplate.opsForValue().set(key, value, time, timeUnit);
    }

    /**
     * 获取key的值
     * @param key
     * @return
     */
    public static Object getKey(String key){
        return redisTemplate.opsForValue().get(key);
    }

    /**
     * 若不存在则设置,若存在返回false
     * @param key
     * @param value
     * @return
     */
    public static Boolean setNx(String key,String value){
        return redisTemplate.opsForValue().setIfAbsent(key, value);
    }

    /**
     * key值累加
     * @param key
     * @param delta
     * @return
     */
    public static Long increment(String key, long delta){
        return redisTemplate.opsForValue().increment(key, delta);
    }


    /**
     * set 增加
     * @param key
     * @param value
     */
    public static void setsAdd(String key, Object... value){
        redisTemplate.opsForSet().add(key,value);
    }

    /**
     * set 删除
     * @param key
     * @param value
     */
    public static void setsRemove(String key, String value){
        redisTemplate.opsForSet().remove(key,value);
    }

    /**
     * set 长度
     * @param key
     * @return
     */
    public static long setsCount(String key){
        return redisTemplate.opsForSet().size(key);
    }

    /**
     * set 内容
     * @param key
     * @return
     */
    public static Set getSets(String key){
        return redisTemplate.opsForSet().members(key);
    }


    /**
     * set 交集
     * @param key
     * @param otherKey
     * @return
     */
    public static Set setIntersect(String key, String otherKey){
        return redisTemplate.opsForSet().intersect(key,otherKey);
    }


    /**
     * set 交集
     * @param key   redis中的set
     * @param list  要比较的集合
     * @return
     */
    public static Set setIntersect(String key, Collection list){
       return setIntersect(key,list.toArray());
    }

    /**
     * set 交集
     * @param key   redis中的set
     * @param arrays  要比较的集合
     * @return
     */
    public static Set setIntersect(String key, Object[] arrays){
        String tempKey = "TMP_CMP_SET-"+ UUID.randomUUID().toString();
        RedisUtils.setsAdd(tempKey,arrays);
        Set intersectSet = RedisUtils.setIntersect(key,tempKey);
        delKey(tempKey);
        return intersectSet;
    }


    public static Boolean hasKey(String key){
        return redisTemplate.hasKey(key);
    }

}
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值