公共线程池创建,公共批量执行方法执行类

单个带参

线程池


import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.List;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicInteger;

/**
 * Holder for a lazily created, shared {@link ThreadPoolExecutor} plus helpers for running a batch
 * of independent tasks on it.
 *
 * <p>NOTE: {@link #getInstance(String)} hands every caller the SAME executor instance. Calling
 * {@link #shutdown(ThreadPoolExecutor, String)} therefore also stops the pool for any other caller
 * that is using it at that moment; the next {@code getInstance} call then creates a fresh pool.
 *
 * @author zhang
 * @since 2024-08-28 17:31
 */
public class CommonThreadPool {
    private static final Logger logger = LoggerFactory.getLogger(CommonThreadPool.class);
    /**
     * Core pool size (核心线程数).
     */
    private static final int CORE_POOL_SIZE = 100;
    /**
     * Maximum pool size (最大线程数); threads beyond the core size are only created once the
     * bounded queue is full.
     */
    private static final int MAX_POOL_SIZE = 500;
    /**
     * Idle time in seconds before a non-core thread is reclaimed (空闲线程存活时间).
     */
    private static final long KEEP_ALIVE_TIME = 60L;
    private static final TimeUnit UNIT = TimeUnit.SECONDS;
    /**
     * Capacity of the bounded work queue.
     */
    private static final int QUEUE_CAPACITY = 5000;
    /**
     * The shared executor; volatile because it is published via double-checked locking.
     */
    private static volatile ThreadPoolExecutor executor;

    private static final ThreadFactory THREAD_FACTORY = new ThreadFactory() {
        private final AtomicInteger threadNumber = new AtomicInteger(1);

        @Override
        public Thread newThread(Runnable r) {
            Thread t = new Thread(r, "ThreadPool-Thread-" + threadNumber.getAndIncrement());
            // Daemon threads do not keep the JVM alive on their own.
            t.setDaemon(true);
            return t;
        }
    };

    private CommonThreadPool() {
        // Utility class: prevent instantiation.
    }

    /**
     * Returns the shared executor, creating it on first use or after it has been shut down.
     * Callers are expected to release it with {@link #shutdown(ThreadPoolExecutor, String)} when done
     * (see the class note: this affects every concurrent user of the pool).
     *
     * @param why reason for acquiring the pool; used only for logging
     * @return the shared executor
     */
    public static ThreadPoolExecutor getInstance(String why) {
        logger.info("正在为[{}]申请线程池", why);
        // Read the volatile field once per check so the value tested is the value returned.
        ThreadPoolExecutor pool = executor;
        if (pool == null || pool.isShutdown()) {
            synchronized (CommonThreadPool.class) {
                pool = executor;
                if (pool == null || pool.isShutdown()) {
                    pool = new ThreadPoolExecutor(
                            CORE_POOL_SIZE,
                            MAX_POOL_SIZE,
                            KEEP_ALIVE_TIME,
                            UNIT,
                            new LinkedBlockingQueue<>(QUEUE_CAPACITY),
                            THREAD_FACTORY,
                            // When saturated, the submitting thread runs the task itself (back-pressure).
                            // DiscardOldestPolicy silently threw queued tasks away; their futures then
                            // only ended when invokeAll's timeout cancelled them.
                            new ThreadPoolExecutor.CallerRunsPolicy()
                    );
                    executor = pool;
                    logger.info("[{}]申请线程池成功", why);
                }
            }
        }
        return pool;
    }

    /**
     * Runs a batch of independent tasks that each return a value, and waits for all of them
     * (执行批量的任务,每个任务都有返回值且相互独立).
     *
     * @param threadPoolExecutor executor to run the tasks on
     * @param tasks              the tasks to run
     * @param timeout            maximum time to wait, in seconds; tasks not finished by then are cancelled
     * @param <T>                result type of each task
     * @return one Future per task, in the same order as {@code tasks}
     * @throws InterruptedException if interrupted while waiting
     */
    public static <T> List<Future<T>> invokeAll(ThreadPoolExecutor threadPoolExecutor, List<? extends Callable<T>> tasks, long timeout) throws InterruptedException {
        return threadPoolExecutor.invokeAll(tasks, timeout, TimeUnit.SECONDS);
    }

    /**
     * Shuts the given executor down: waits up to 60 seconds for running and queued tasks, then
     * forces termination (释放线程池资源).
     *
     * @param threadPoolExecutor the executor to release; {@code null} or already shut down is a no-op
     * @param why                reason / caller id; used only for logging
     */
    public static void shutdown(ThreadPoolExecutor threadPoolExecutor, String why) {
        if (threadPoolExecutor != null && !threadPoolExecutor.isShutdown()) {
            threadPoolExecutor.shutdown();
            try {
                if (!threadPoolExecutor.awaitTermination(60L, TimeUnit.SECONDS)) {
                    threadPoolExecutor.shutdownNow();
                    logger.warn("[{}]线程池等待终止超时,已强制关闭", why);
                    return;
                }
            } catch (InterruptedException e) {
                threadPoolExecutor.shutdownNow();
                // Preserve the caller's interrupt status.
                Thread.currentThread().interrupt();
                logger.warn("[{}]等待线程池终止时被中断,已强制关闭", why);
                return;
            }
        }
        logger.info("[{}]线程池资源正常释放", why);
    }

}

上下文获取工具类

import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextAware;
import org.springframework.stereotype.Component;

/**
 * Exposes the Spring {@link ApplicationContext} to code that is not itself a Spring bean, such as
 * tasks running on a thread pool.
 *
 * @author zhang
 * @since 2024-08-28 17:58
 */
@Component
public class ContextLocator implements ApplicationContextAware {
    /**
     * Written by Spring during startup and read from other (e.g. worker) threads, hence volatile.
     */
    private static volatile ApplicationContext applicationContext;

    /**
     * @return the context injected by Spring, or {@code null} if this bean has not been initialised yet
     */
    public static ApplicationContext getApplicationContext() {
        return applicationContext;
    }

    @Override
    public void setApplicationContext(ApplicationContext applicationContext) {
        ContextLocator.applicationContext = applicationContext;
    }

    /**
     * Looks up a bean by name.
     *
     * @param beanName bean name
     * @return the bean
     * @throws IllegalStateException if the context has not been initialised yet
     */
    public static Object getBean(String beanName) {
        return requireContext().getBean(beanName);
    }

    /**
     * Looks up a bean by type (typed alternative to {@link #getBean(String)}).
     *
     * @param beanClass bean type
     * @param <T>       bean type
     * @return the bean
     * @throws IllegalStateException if the context has not been initialised yet
     */
    public static <T> T getBean(Class<T> beanClass) {
        return requireContext().getBean(beanClass);
    }

    /** Returns the context, failing with a clear message instead of a NullPointerException. */
    private static ApplicationContext requireContext() {
        ApplicationContext context = applicationContext;
        if (context == null) {
            throw new IllegalStateException("ApplicationContext is not initialized.");
        }
        return context;
    }
}

批量执行公共类


import org.springframework.context.ApplicationContext;

import java.util.concurrent.Callable;

/**
 * Generic {@link Callable} that invokes a public one-argument method on a Spring bean, so a batch
 * job does not need a hand-written Callable per business method.
 * Supports exactly one argument; use with care when the business logic depends on the returned value.
 *
 * @param <T> argument type
 * @param <R> return type of the target method
 * @author zhang
 * @since 2024-08-28 17:47
 */
public class CommonCallable<T, R> implements Callable<R> {
    private final T data;
    private final Class<?> beanClass;
    private final String methodName;
    private final Class<R> returnType;

    /**
     * Creates the task.
     *
     * @param data       the method argument (may be {@code null} for a reference-typed parameter)
     * @param beanClass  type of the Spring bean that declares the method
     * @param methodName name of the public method to invoke
     * @param returnType expected return type
     */
    public CommonCallable(T data, Class<?> beanClass, String methodName, Class<R> returnType) {
        this.data = data;
        this.beanClass = beanClass;
        this.methodName = methodName;
        this.returnType = returnType;
    }

    /**
     * Looks up the bean, invokes {@code methodName(data)} on it and casts the result.
     *
     * @return the method's return value, cast to {@code R}
     * @throws IllegalStateException if the Spring context is not available yet
     * @throws NoSuchMethodException if no public one-argument method accepts {@code data}
     * @throws Exception             whatever the target method itself throws (unwrapped)
     */
    @Override
    public R call() throws Exception {
        ApplicationContext context = ContextLocator.getApplicationContext();
        if (context == null) {
            throw new IllegalStateException("ApplicationContext is not initialized");
        }
        Object bean = context.getBean(beanClass);
        java.lang.reflect.Method method = resolveMethod(bean.getClass());
        try {
            return returnType.cast(method.invoke(bean, data));
        } catch (java.lang.reflect.InvocationTargetException e) {
            // Surface what the business method threw rather than the reflection wrapper.
            Throwable cause = e.getCause();
            if (cause instanceof Exception) {
                throw (Exception) cause;
            }
            if (cause instanceof Error) {
                throw (Error) cause;
            }
            throw e;
        }
    }

    /**
     * Finds the public one-argument method {@code methodName} on {@code type} that accepts {@code data}.
     * Tries an exact match on the argument's runtime class first, then any compatible parameter
     * (supertype, interface, or the primitive counterpart of a wrapper).
     */
    private java.lang.reflect.Method resolveMethod(Class<?> type) throws NoSuchMethodException {
        if (data != null) {
            try {
                return type.getMethod(methodName, data.getClass());
            } catch (NoSuchMethodException ignored) {
                // Fall through to the compatibility scan below.
            }
        }
        for (java.lang.reflect.Method candidate : type.getMethods()) {
            if (candidate.getName().equals(methodName)
                    && candidate.getParameterCount() == 1
                    && accepts(candidate.getParameterTypes()[0])) {
                return candidate;
            }
        }
        throw new NoSuchMethodException(type.getName() + "." + methodName + "("
                + (data == null ? "null" : data.getClass().getName()) + ")");
    }

    /** Whether {@code data} can be passed to a parameter of the given type via reflection. */
    private boolean accepts(Class<?> parameterType) {
        if (data == null) {
            // null can only be passed to reference-typed parameters.
            return !parameterType.isPrimitive();
        }
        Class<?> boxed = parameterType.isPrimitive()
                ? java.lang.invoke.MethodType.methodType(parameterType).wrap().returnType()
                : parameterType;
        return boxed.isInstance(data);
    }
}

测试公共返回类

/**
 * Simple response wrapper used by the demo services and controllers.
 *
 * <p>Lombok {@code @Data} generates getters/setters, equals/hashCode and toString. Lombok does not
 * generate a constructor once one is declared explicitly, so the no-arg constructor is written out.
 */
@Data
public class CommonResult {
    private boolean success;
    private String code;
    private String message;
    private Object object;

    public CommonResult() {
    }

    /**
     * Argument order matches the demo calls {@code new CommonResult(flag, "", value, text)}.
     *
     * @param success whether the call succeeded
     * @param code    result code
     * @param object  payload
     * @param message message text
     */
    public CommonResult(boolean success, String code, Object object, String message) {
        this.success = success;
        this.code = code;
        this.object = object;
        this.message = message;
    }

    /**
     * Creates a successful result carrying the given payload, with empty code and message.
     *
     * @param object payload
     * @return a result with {@code success == true}
     */
    public static CommonResult success(Object object) {
        return new CommonResult(true, "", object, "");
    }
}

测试业务类

import org.springframework.stereotype.Service;

/**
 * Demo service used by the single-argument thread-pool test endpoint.
 *
 * @author zhang
 * @since 2024-08-28 18:12
 */
@Service
public class StudentService {
    /**
     * Simulates a business call whose outcome depends on whether the current millisecond is even.
     *
     * @param count input value; returned as the payload when the simulated call fails
     * @return the simulated result
     */
    public CommonResult tall(Integer count) {
        final boolean even = (System.currentTimeMillis() & 1L) == 0L;
        final int payload;
        if (even) {
            payload = 0;
        } else {
            payload = count;
        }
        return new CommonResult(even, "", payload, "");
    }
}

测试控制器

    /**
     * Demo endpoint: runs 100 {@link CommonCallable} tasks against {@code StudentService.tall} on
     * the shared pool and returns the collected results.
     *
     * <p>A task that failed or was cancelled at the timeout is skipped, so one bad task does not
     * discard the results of the others.
     *
     * @return CommonResult wrapping the list of successful task results
     */
    @GetMapping("/testThreads")
    public CommonResult testThread() {
        String name = UUID.randomUUID().toString();
        ThreadPoolExecutor test1 = CommonThreadPool.getInstance(name);
        List<CommonResult> a = new ArrayList<>();
        try {
            List<CommonCallable<Integer, CommonResult>> callableList = new ArrayList<>();
            for (int i = 0; i < 100; i++) {
                callableList.add(new CommonCallable<>(i, StudentService.class, "tall", CommonResult.class));
            }
            // invokeAll returns only when every task has completed or been cancelled at the
            // 30 s timeout, so future.get() below does not block.
            List<Future<CommonResult>> futures = CommonThreadPool.invokeAll(test1, callableList, 30L);
            for (Future<CommonResult> future : futures) {
                try {
                    a.add(future.get());
                } catch (InterruptedException ie) {
                    throw ie;
                } catch (Exception ignored) {
                    // ExecutionException (task threw) or CancellationException (timed out):
                    // skip this task but keep the other results.
                    // TODO: log once a logger is available in this controller.
                }
            }
        } catch (InterruptedException e) {
            // Preserve the request thread's interrupt status.
            Thread.currentThread().interrupt();
        } catch (Exception e) {
            // 记录异常 — TODO: log; no logger is visible in this snippet.
        } finally {
            // NOTE(review): the pool is a shared singleton, so this also affects concurrent requests.
            CommonThreadPool.shutdown(test1, name);
        }
        return CommonResult.success(a);
    }

2024-08-29新增,多个带参情况

公共类(多参)


import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.context.ApplicationContext;

import java.util.concurrent.Callable;

/**
 * Generic {@link Callable} that invokes a public multi-argument method on a Spring bean, so a batch
 * job does not need a hand-written Callable per business method.
 *
 * <p>The value returned by the target method must be an instance of {@code returnType} (or
 * {@code null}); use with care when the business logic depends on the returned value.
 *
 * @param <R> return type of the target method
 * @author zhang
 * @since 2024-08-29 10:00
 */
public class CommonMoreCallable<R> implements Callable<R> {

    private final static Logger LOG = LoggerFactory.getLogger(CommonMoreCallable.class);
    /** Arguments for the target method (defensive copy). */
    private final Object[] data;
    /** Declared parameter types used to look up the method (defensive copy). */
    private final Class<?>[] dataClass;
    private final Class<?> beanClass;
    private final String methodName;
    private final Class<R> returnType;

    /**
     * Creates the task. Argument validation happens in {@link #call()}, so problems surface as a
     * failed Future for this task only.
     *
     * @param data       method arguments, in parameter order
     * @param dataClass  declared parameter types of the target method, in the same order as {@code data}
     * @param beanClass  type of the Spring bean that declares the method
     * @param methodName name of the public method to invoke
     * @param returnType expected return type
     */
    public CommonMoreCallable(Object[] data, Class<?>[] dataClass, Class<?> beanClass, String methodName, Class<R> returnType) {
        // Copy the arrays: the task may run after the caller has modified or reused them.
        this.data = data == null ? null : data.clone();
        this.dataClass = dataClass == null ? null : dataClass.clone();
        this.beanClass = beanClass;
        this.methodName = methodName;
        this.returnType = returnType;
    }

    /**
     * Validates the arguments, looks up the bean and invokes the method.
     *
     * @return the method's return value cast to {@code R} (may be {@code null})
     * @throws IllegalArgumentException if the arguments do not match the declared parameter types
     * @throws IllegalStateException    if the Spring context is not available yet
     * @throws Exception                whatever the target method itself throws (unwrapped)
     */
    @Override
    public R call() throws Exception {
        if (data == null || dataClass == null) {
            LOG.error("data and dataClass must not be null");
            throw new IllegalArgumentException("data and dataClass must not be null");
        }
        if (data.length != dataClass.length) {
            LOG.error("参数数组长度和参数类型数组长度不一致");
            // Throwing fails only this task.
            throw new IllegalArgumentException("参数数组长度和参数类型数组长度不一致");
        }
        for (int i = 0; i < data.length; i++) {
            // Compare against the wrapper type so primitive parameters (e.g. int.class) accept boxed values.
            if (data[i] != null && !boxed(dataClass[i]).isInstance(data[i])) {
                LOG.error("参数和参数类型不一致");
                throw new IllegalArgumentException("参数和参数类型不一致");
            }
        }
        ApplicationContext context = ContextLocator.getApplicationContext();
        if (context == null) {
            throw new IllegalStateException("ApplicationContext is not initialized");
        }
        Object bean = context.getBean(beanClass);
        Object invoke;
        try {
            invoke = bean.getClass().getMethod(methodName, dataClass).invoke(bean, data);
        } catch (java.lang.reflect.InvocationTargetException e) {
            // Surface what the business method threw rather than the reflection wrapper.
            Throwable cause = e.getCause();
            if (cause instanceof Exception) {
                throw (Exception) cause;
            }
            if (cause instanceof Error) {
                throw (Error) cause;
            }
            throw e;
        }
        // null is a valid result for a reference return type; only a non-null value of another type mismatches.
        if (invoke != null && !returnType.isInstance(invoke)) {
            LOG.error("返回类型不匹配");
        }
        // Throws ClassCastException on a real mismatch.
        return returnType.cast(invoke);
    }

    /** Maps a primitive class to its wrapper class; other classes are returned unchanged. */
    private static Class<?> boxed(Class<?> type) {
        return type.isPrimitive()
                ? java.lang.invoke.MethodType.methodType(type).wrap().returnType()
                : type;
    }
}

测试业务类

import org.springframework.stereotype.Service;

/**
 * Demo service with a two-argument method, used by the multi-argument thread-pool test endpoint.
 *
 * @author zhang
 * @since 2024-08-29 10:00
 */
@Service
public class StudentMoreService {
    /**
     * Simulates a business call; success depends on whether the current millisecond is even.
     *
     * @param count sequence number of the task
     * @param name  caller id
     * @return result whose message is {@code count + "_" + name}
     */
    public CommonResult tall(Integer count, String name) {
        // Simulated business logic (模拟业务)
        final long now = System.currentTimeMillis();
        final boolean even = now % 2L == 0;
        final String label = count + "_" + name;
        System.out.println(label);
        final long payload = even ? 0L : now;
        return new CommonResult(even, "", payload, label);
    }
}

测试控制器

    /**
     * Demo endpoint: runs 100 {@link CommonMoreCallable} tasks against
     * {@code StudentMoreService.tall(Integer, String)} on the shared pool and returns the results.
     * Single-argument methods can be run the same way with {@code CommonCallable}.
     *
     * <p>A task that failed or was cancelled at the timeout is logged and skipped, so one bad task
     * does not discard the results of the others.
     *
     * @return CommonResult wrapping the list of successful task results
     */
    @GetMapping("/testThreads")
    public CommonResult testThread() {
        String name = UUID.randomUUID().toString();
        ThreadPoolExecutor test1 = CommonThreadPool.getInstance(name);
        List<CommonResult> a = new ArrayList<>();
        try {
            List<CommonMoreCallable<CommonResult>> callableList = new ArrayList<>();
            for (int i = 0; i < 100; i++) {
                callableList.add(new CommonMoreCallable<>(new Object[]{i, name}, new Class<?>[]{Integer.class, String.class}, StudentMoreService.class, "tall", CommonResult.class));
            }
            // invokeAll returns only when every task has completed or been cancelled at the
            // 30 s timeout, so future.get() below does not block.
            List<Future<CommonResult>> futures = CommonThreadPool.invokeAll(test1, callableList, 30L);
            for (Future<CommonResult> future : futures) {
                try {
                    a.add(future.get());
                } catch (InterruptedException ie) {
                    throw ie;
                } catch (Exception taskFailure) {
                    // ExecutionException (task threw) or CancellationException (timed out).
                    LOG.error("[" + name + "] batch task failed or timed out, skipping it", taskFailure);
                }
            }
        } catch (InterruptedException e) {
            // Preserve the request thread's interrupt status.
            Thread.currentThread().interrupt();
            LOG.error("[" + name + "] interrupted while running batch tasks", e);
        } catch (Exception e) {
            // Pass the exception itself so the stack trace is logged, not just the message.
            LOG.error("[" + name + "] batch execution failed", e);
        } finally {
            // NOTE(review): the pool is a shared singleton, so this also affects concurrent requests.
            CommonThreadPool.shutdown(test1, name);
        }
        return CommonResult.success(a);
    }

写这个的想法来源于去年工作中遇到的一个问题:业务中有一个批量执行的操作,客服一次开多个浏览器窗口同时进行批量操作,每次操作有 300+ 条数据需要入库、查库并请求外部接口,导致应用的其他接口请求超时(说白了就是应用差点崩了)。经过排查,发现这里使用了线程池 ThreadPoolExecutor,再通过 invokeAll 去执行。但是前人挖了两个坑:一是申请的线程池资源在执行结束后没有释放;二是在方法内用 for 循环逐条执行这个所谓的批量操作。后来经过改造,测试后能达到预期效果。今年工作中又需要用到批量执行的操作,联想到去年的问题,就想能不能写一个公共类来代替繁琐的 Callable 接口实现,于是有了这篇文章。若有大佬发现文章内容有不对的地方,劳烦指出,感谢。
附上之前改写的模拟业务代码


/**
 * Hand-written {@link Callable} from the earlier refactoring: uploads one {@code StudentInfo}
 * through the {@code StudentManager} bean.
 */
public class StudentCallable implements Callable<CommonResult> {
    final StudentInfo studentInfo;

    public StudentCallable(StudentInfo studentInfo) {
        this.studentInfo = studentInfo;
    }

    /**
     * @return the result of {@code StudentManager.upload} for this student
     * @throws Exception whatever the upload throws
     */
    @Override
    public CommonResult call() throws Exception {
        // ContextLocator.getBean(String) returns Object, so resolve the bean by type from the context.
        StudentManager studentManager = ContextLocator.getApplicationContext().getBean(StudentManager.class);
        return studentManager.upload(this.studentInfo);
    }
}

///		省略上面代码 (code above omitted)
    // Build one task per student (执行操作).
    // NOTE(review): assumes studentInfoList holds StudentInfo, the type StudentCallable accepts — confirm.
    List<StudentCallable> records = Lists.newArrayList();
    for (StudentInfo studentInfo : studentInfoList) {
        records.add(new StudentCallable(studentInfo));
    }
    /* 申请线程池资源 (acquire the shared pool) */
    ThreadPoolExecutor studentUp = CommonThreadPool.getInstance("业务");
    /* 执行批量的任务 (run the batch, waiting at most 60 s) */
    List<Future<CommonResult>> results = studentUp.invokeAll(records, 60L, TimeUnit.SECONDS);
///		省略下面代码 (code below omitted; the pool is expected to be released via CommonThreadPool.shutdown)

最后优化

2024-08-31


import com.google.common.util.concurrent.ThreadFactoryBuilder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.lang.reflect.Method;
import java.util.List;
import java.util.concurrent.*;

/**
 * Holder for a lazily created, shared {@link ThreadPoolExecutor} plus helpers for running a batch
 * of independent tasks on it (final version).
 *
 * <p>NOTE: {@link #getInstance(String)} hands every caller the SAME executor instance. Calling
 * {@link #shutdown(ThreadPoolExecutor, String)} therefore also stops the pool for any other caller
 * that is using it at that moment; the next {@code getInstance} call then creates a fresh pool.
 *
 * @author zhang
 * @since 2024-08-28 17:31
 */
public class CommonThreadPool {
    private static final Logger logger = LoggerFactory.getLogger(CommonThreadPool.class);
    /** Core pool size. */
    private static final int CORE_POOL_SIZE = 100;
    /** Maximum pool size; extra threads are only created once the bounded queue is full. */
    private static final int MAX_POOL_SIZE = 500;
    /** Idle time in seconds before a non-core thread is reclaimed. */
    private static final long KEEP_ALIVE_TIME = 60L;
    private static final TimeUnit UNIT = TimeUnit.SECONDS;
    /** Capacity of the bounded work queue. */
    private static final int QUEUE_CAPACITY = 5000;
    /** The shared executor; volatile because it is published via double-checked locking. */
    private static volatile ThreadPoolExecutor executor;

    private static final ThreadFactory THREAD_FACTORY = new ThreadFactoryBuilder()
            .setNameFormat("ThreadPool-Thread-%d")
            .setDaemon(true)
            .build();

    private CommonThreadPool() {
        // Utility class: prevent instantiation.
    }

    /**
     * Returns the shared executor, creating it on first use or after it has been shut down.
     *
     * @param reason reason for acquiring the pool; used only for logging
     * @return the shared executor
     */
    public static ThreadPoolExecutor getInstance(String reason) {
        logger.info("Attempting to acquire thread pool for [{}]", reason);
        // Read the volatile field once per check so the value tested is the value returned.
        ThreadPoolExecutor pool = executor;
        if (pool == null || pool.isShutdown()) {
            synchronized (CommonThreadPool.class) {
                pool = executor;
                if (pool == null || pool.isShutdown()) {
                    pool = new ThreadPoolExecutor(
                            CORE_POOL_SIZE,
                            MAX_POOL_SIZE,
                            KEEP_ALIVE_TIME,
                            UNIT,
                            new LinkedBlockingQueue<>(QUEUE_CAPACITY),
                            THREAD_FACTORY,
                            // When saturated, the submitting thread runs the task itself (back-pressure).
                            new ThreadPoolExecutor.CallerRunsPolicy()
                    );
                    executor = pool;
                    logger.info("[{}] Acquired thread pool successfully", reason);
                }
            }
        }
        return pool;
    }

    /**
     * Runs a batch of independent tasks and waits for all of them.
     *
     * @param threadPoolExecutor executor to run the tasks on
     * @param tasks              the tasks to run
     * @param timeout            maximum time to wait, in seconds; tasks not finished by then are cancelled
     * @param <T>                result type of each task
     * @return one Future per task, in the same order as {@code tasks}
     * @throws InterruptedException if interrupted while waiting
     */
    public static <T> List<Future<T>> invokeAll(ThreadPoolExecutor threadPoolExecutor, List<? extends Callable<T>> tasks, long timeout) throws InterruptedException {
        return threadPoolExecutor.invokeAll(tasks, timeout, TimeUnit.SECONDS);
    }

    /**
     * Shuts the given executor down: waits up to 60 seconds for running and queued tasks, then
     * forces termination.
     *
     * @param threadPoolExecutor the executor to release; {@code null} or already shut down is a no-op
     * @param reason             reason / caller id; used only for logging
     */
    public static void shutdown(ThreadPoolExecutor threadPoolExecutor, String reason) {
        if (threadPoolExecutor != null && !threadPoolExecutor.isShutdown()) {
            threadPoolExecutor.shutdown();
            try {
                if (!threadPoolExecutor.awaitTermination(60L, TimeUnit.SECONDS)) {
                    threadPoolExecutor.shutdownNow();
                    logger.warn("[{}] Forced shutdown of the thread pool due to timeout", reason);
                    return;
                }
            } catch (InterruptedException e) {
                threadPoolExecutor.shutdownNow();
                // Preserve the caller's interrupt status.
                Thread.currentThread().interrupt();
                logger.warn("[{}] Interrupted while awaiting termination; thread pool forcibly shut down", reason);
                return;
            }
        }
        logger.info("[{}] Thread pool resources released normally", reason);
    }

    /**
     * Looks up a public method (including inherited ones) on the given class.
     *
     * @param beanClass      class to search
     * @param methodName     method name
     * @param parameterTypes declared parameter types
     * @return the method
     * @throws IllegalArgumentException if no such public method exists (cause: the NoSuchMethodException)
     */
    public static Method getMethod(Class<?> beanClass, String methodName, Class<?>... parameterTypes) {
        try {
            return beanClass.getMethod(methodName, parameterTypes);
        } catch (NoSuchMethodException e) {
            throw new IllegalArgumentException("No public method " + beanClass.getName() + "." + methodName
                    + java.util.Arrays.toString(parameterTypes), e);
        }
    }
}


import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextAware;
import org.springframework.stereotype.Component;
import org.springframework.beans.BeansException;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Exposes the Spring {@link ApplicationContext} to code that is not itself a Spring bean, such as
 * tasks running on a thread pool. The first context handed to
 * {@link #setApplicationContext(ApplicationContext)} wins; later ones are ignored with a warning.
 * NOTE(review): after a context restart (e.g. a dev-tools reload) the stored context may be a
 * closed one — confirm whether that matters for this application.
 *
 * @author zhang
 * @since 2024-08-28 17:58
 */
@Component
public class ContextLocator implements ApplicationContextAware {
    private static final Logger LOG = LoggerFactory.getLogger(ContextLocator.class);
    /**
     * volatile: required for the double-checked locking below and for the unsynchronised reads
     * from other threads.
     */
    private static volatile ApplicationContext applicationContext;

    /**
     * @return the context injected by Spring, or {@code null} (with a warning) if not initialised yet
     */
    public static ApplicationContext getApplicationContext() {
        ApplicationContext context = applicationContext;
        if (context == null) {
            LOG.warn("ApplicationContext has not been initialized yet.");
        }
        return context;
    }

    @Override
    public void setApplicationContext(ApplicationContext applicationContext) {
        if (ContextLocator.applicationContext == null) {
            synchronized (ContextLocator.class) {
                if (ContextLocator.applicationContext == null) {
                    ContextLocator.applicationContext = applicationContext;
                    LOG.info("ApplicationContext initialized successfully.");
                }
            }
        } else {
            LOG.warn("ApplicationContext was already initialized.");
        }
    }

    /**
     * Looks up a bean by name.
     *
     * @param beanName bean name
     * @return the bean
     * @throws IllegalStateException if the context has not been initialised yet
     * @throws BeansException        if Spring cannot provide the bean
     */
    public static Object getBean(String beanName) {
        ApplicationContext context = applicationContext;
        if (context == null) {
            LOG.error("ApplicationContext is not initialized. Cannot retrieve bean [{}].", beanName);
            throw new IllegalStateException("ApplicationContext is not initialized.");
        }
        try {
            return context.getBean(beanName);
        } catch (BeansException e) {
            LOG.error("Failed to retrieve bean [{}].", beanName, e);
            throw e;
        }
    }

    /**
     * Looks up a bean by type (typed alternative to {@link #getBean(String)}).
     *
     * @param beanClass bean type
     * @param <T>       bean type
     * @return the bean
     * @throws IllegalStateException if the context has not been initialised yet
     * @throws BeansException        if Spring cannot provide the bean
     */
    public static <T> T getBean(Class<T> beanClass) {
        ApplicationContext context = applicationContext;
        if (context == null) {
            LOG.error("ApplicationContext is not initialized. Cannot retrieve bean [{}].", beanClass);
            throw new IllegalStateException("ApplicationContext is not initialized.");
        }
        try {
            return context.getBean(beanClass);
        } catch (BeansException e) {
            LOG.error("Failed to retrieve bean [{}].", beanClass, e);
            throw e;
        }
    }
}


import org.springframework.context.ApplicationContext;

import java.util.concurrent.Callable;

/**
 * Generic {@link Callable} that invokes a public one-argument method on a Spring bean, so a batch
 * job does not need a hand-written Callable per business method.
 * Supports exactly one argument; use with care when the business logic depends on the returned value.
 *
 * @param <T> argument type
 * @param <R> return type of the target method
 * @author zhang
 * @since 2024-08-28 17:47
 */
public class CommonCallable<T, R> implements Callable<R> {
    private final T data;
    private final Class<?> beanClass;
    private final String methodName;
    private final Class<R> returnType;

    /**
     * Creates the task.
     *
     * @param data       the method argument (may be {@code null} for a reference-typed parameter)
     * @param beanClass  type of the Spring bean that declares the method
     * @param methodName name of the public method to invoke
     * @param returnType expected return type
     */
    public CommonCallable(T data, Class<?> beanClass, String methodName, Class<R> returnType) {
        this.data = data;
        this.beanClass = beanClass;
        this.methodName = methodName;
        this.returnType = returnType;
    }

    /**
     * Looks up the bean, invokes {@code methodName(data)} on it and casts the result.
     *
     * @return the method's return value, cast to {@code R}
     * @throws IllegalStateException if the Spring context is not available yet
     * @throws NoSuchMethodException if no public one-argument method accepts {@code data}
     * @throws Exception             whatever the target method itself throws (unwrapped)
     */
    @Override
    public R call() throws Exception {
        ApplicationContext context = ContextLocator.getApplicationContext();
        if (context == null) {
            throw new IllegalStateException("ApplicationContext is not initialized");
        }
        Object bean = context.getBean(beanClass);
        java.lang.reflect.Method method = resolveMethod(bean.getClass());
        try {
            return returnType.cast(method.invoke(bean, data));
        } catch (java.lang.reflect.InvocationTargetException e) {
            // Surface what the business method threw rather than the reflection wrapper.
            Throwable cause = e.getCause();
            if (cause instanceof Exception) {
                throw (Exception) cause;
            }
            if (cause instanceof Error) {
                throw (Error) cause;
            }
            throw e;
        }
    }

    /**
     * Finds the public one-argument method {@code methodName} on {@code type} that accepts {@code data}.
     * Tries an exact match on the argument's runtime class first, then any compatible parameter
     * (supertype, interface, or the primitive counterpart of a wrapper).
     */
    private java.lang.reflect.Method resolveMethod(Class<?> type) throws NoSuchMethodException {
        if (data != null) {
            try {
                return type.getMethod(methodName, data.getClass());
            } catch (NoSuchMethodException ignored) {
                // Fall through to the compatibility scan below.
            }
        }
        for (java.lang.reflect.Method candidate : type.getMethods()) {
            if (candidate.getName().equals(methodName)
                    && candidate.getParameterCount() == 1
                    && accepts(candidate.getParameterTypes()[0])) {
                return candidate;
            }
        }
        throw new NoSuchMethodException(type.getName() + "." + methodName + "("
                + (data == null ? "null" : data.getClass().getName()) + ")");
    }

    /** Whether {@code data} can be passed to a parameter of the given type via reflection. */
    private boolean accepts(Class<?> parameterType) {
        if (data == null) {
            // null can only be passed to reference-typed parameters.
            return !parameterType.isPrimitive();
        }
        Class<?> boxed = parameterType.isPrimitive()
                ? java.lang.invoke.MethodType.methodType(parameterType).wrap().returnType()
                : parameterType;
        return boxed.isInstance(data);
    }
}


import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.context.ApplicationContext;

import java.lang.reflect.Method;
import java.util.Arrays;
import java.util.concurrent.Callable;

/**
 * Generic {@link Callable} that invokes a public multi-argument method on a Spring bean, so a batch
 * job does not need a hand-written Callable per business method.
 *
 * <p>The value returned by the target method must be an instance of {@code returnType} (or
 * {@code null}); use with care when the business logic depends on the returned value.
 *
 * @param <R> return type of the target method
 * @author zhang
 * @since 2024-08-29 10:00
 */
public class CommonMoreCallable<R> implements Callable<R> {

    private final static Logger LOG = LoggerFactory.getLogger(CommonMoreCallable.class);
    /** Arguments for the target method (defensive copy). */
    private final Object[] data;
    /** Declared parameter types used to look up the method (defensive copy). */
    private final Class<?>[] dataClass;
    private final Class<?> beanClass;
    private final String methodName;
    private final Class<R> returnType;

    /**
     * Creates the task, validating the arguments eagerly.
     *
     * @param data       method arguments, in parameter order
     * @param dataClass  declared parameter types of the target method, in the same order as {@code data}
     * @param beanClass  type of the Spring bean that declares the method
     * @param methodName name of the public method to invoke
     * @param returnType expected return type
     * @throws IllegalArgumentException if any argument is null or the two arrays differ in length
     */
    public CommonMoreCallable(Object[] data, Class<?>[] dataClass, Class<?> beanClass, String methodName, Class<R> returnType) {
        if (data == null || dataClass == null || beanClass == null || methodName == null || returnType == null) {
            throw new IllegalArgumentException("Arguments cannot be null");
        }
        if (data.length != dataClass.length) {
            throw new IllegalArgumentException("Data array length does not match data class array length");
        }
        // Copy the arrays: the task may run after the caller has modified or reused them.
        this.data = data.clone();
        this.dataClass = dataClass.clone();
        this.beanClass = beanClass;
        this.methodName = methodName;
        this.returnType = returnType;
    }

    /**
     * Looks up the bean and invokes the method with the stored arguments.
     *
     * @return the method's return value cast to {@code R} (may be {@code null})
     * @throws IllegalStateException if the Spring context or the bean is not available
     * @throws NoSuchMethodException if {@code beanClass} has no public method with these parameter types
     * @throws ClassCastException    if the method returned a non-null value that is not a {@code returnType}
     * @throws Exception             whatever the target method itself throws (unwrapped)
     */
    @Override
    public R call() throws Exception {
        ApplicationContext context = ContextLocator.getApplicationContext();
        if (context == null) {
            LOG.error("ApplicationContext is not initialized; cannot look up bean {}", beanClass.getName());
            throw new IllegalStateException("ApplicationContext is not initialized");
        }
        Object bean = context.getBean(beanClass);

        if (bean == null) {
            LOG.error("Bean with class {} not found in application context", beanClass.getName());
            throw new IllegalStateException("Bean not found in application context");
        }

        Method method;
        try {
            method = beanClass.getMethod(methodName, dataClass);
        } catch (NoSuchMethodException e) {
            LOG.error("Method {} with parameters {} not found in {}", methodName, Arrays.toString(dataClass), beanClass.getName(), e);
            // Rethrow the original: its message names the class, method and parameter types.
            throw e;
        }

        Object result;
        try {
            result = method.invoke(bean, data);
        } catch (java.lang.reflect.InvocationTargetException e) {
            // Surface what the business method threw rather than the reflection wrapper.
            Throwable cause = e.getCause();
            LOG.error("Method {} on bean of type {} threw an exception", methodName, beanClass.getName(), cause);
            if (cause instanceof Exception) {
                throw (Exception) cause;
            }
            if (cause instanceof Error) {
                throw (Error) cause;
            }
            throw e;
        } catch (Exception e) {
            LOG.error("Error invoking method {} on bean of type {}", methodName, beanClass.getName(), e);
            throw e;
        }

        // null is a valid result for a reference return type; only a non-null value of another type mismatches.
        if (result != null && !returnType.isInstance(result)) {
            LOG.error("Return value of method {} is not of expected type {}", methodName, returnType.getName());
            throw new ClassCastException("Return type mismatch");
        }

        return returnType.cast(result);
    }
}

使用示例
CommonMoreCallable使用示例

以下是一个简单的 Java 公共线程池,提供获取线程池实例的方法:

```java
import java.util.concurrent.*;

public class ThreadPool {
    private static ThreadPool instance;
    private ExecutorService executor;

    private ThreadPool() {
        int corePoolSize = 10;
        int maxPoolSize = 100;
        long keepAliveTime = 60L;
        TimeUnit timeUnit = TimeUnit.SECONDS;
        BlockingQueue<Runnable> workQueue = new LinkedBlockingQueue<>();
        executor = new ThreadPoolExecutor(corePoolSize, maxPoolSize, keepAliveTime, timeUnit, workQueue);
    }

    public static synchronized ThreadPool getInstance() {
        if (instance == null) {
            instance = new ThreadPool();
        }
        return instance;
    }

    public void execute(Runnable task) {
        executor.execute(task);
    }

    public Future<?> submit(Runnable task) {
        return executor.submit(task);
    }

    public <T> Future<T> submit(Callable<T> task) {
        return executor.submit(task);
    }

    public void shutdown() {
        executor.shutdown();
    }

    public boolean isShutdown() {
        return executor.isShutdown();
    }

    public boolean isTerminated() {
        return executor.isTerminated();
    }

    public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedException {
        return executor.awaitTermination(timeout, unit);
    }
}
```

这里使用了单例模式来确保只有一个线程池对象存在,可以通过 `ThreadPool.getInstance()` 方法获取该对象。线程池使用 `ThreadPoolExecutor` 实现,提供了 `execute()`、`submit()` 方法来提交任务,以及 `shutdown()`、`isShutdown()`、`isTerminated()`、`awaitTermination()` 方法来控制线程池的关闭并查询其状态。需要注意:该示例的工作队列是无界的 `LinkedBlockingQueue`,任务会一直排队而不会触发创建额外线程,因此线程数不会超过 `corePoolSize`(10),`maxPoolSize` 实际上不起作用。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值