单个参数的情况
线程池
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.List;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicInteger;
/**
 * Shared, lazily created thread pool for independent batch tasks.
 *
 * <p>Each {@link #getInstance(String)} call takes a lease on the shared pool and must be paired
 * with exactly one {@link #shutdown(ThreadPoolExecutor, String)} call. The shared pool is only
 * shut down when the last outstanding lease is released, so one caller finishing can no longer
 * shut down the pool underneath another caller that obtained the same instance concurrently
 * (after {@code shutdown()} new submissions are rejected and discarded by the rejection policy,
 * and tasks still running when the shutdown timeout elapses are interrupted via
 * {@code shutdownNow()}).
 *
 * @author zhang
 * @since 2024-08-28 17:31
 */
public class CommonThreadPool {
    private static final Logger logger = LoggerFactory.getLogger(CommonThreadPool.class);
    /**
     * Core pool size.
     */
    private static final int CORE_POOL_SIZE = 100;
    /**
     * Maximum pool size. With the bounded queue below, threads beyond the core size are only
     * started once the queue is full.
     */
    private static final int MAX_POOL_SIZE = 500;
    /**
     * Idle time (seconds) after which threads above the core size are reclaimed.
     */
    private static final long KEEP_ALIVE_TIME = 60L;
    private static final TimeUnit UNIT = TimeUnit.SECONDS;
    /**
     * Capacity of the work queue.
     */
    private static final int QUEUE_CAPACITY = 5000;
    /**
     * Seconds {@link #shutdown} waits for in-flight tasks before forcing termination.
     */
    private static final long SHUTDOWN_TIMEOUT_SECONDS = 60L;
    private static volatile ThreadPoolExecutor executor;
    /**
     * Number of callers currently holding the shared executor; guarded by
     * {@code CommonThreadPool.class}.
     */
    private static int leaseCount;
    private static final ThreadFactory THREAD_FACTORY = new ThreadFactory() {
        private final AtomicInteger threadNumber = new AtomicInteger(1);

        @Override
        public Thread newThread(Runnable r) {
            Thread t = new Thread(r, "ThreadPool-Thread-" + threadNumber.getAndIncrement());
            t.setDaemon(true);
            return t;
        }
    };

    private CommonThreadPool() {
        // Utility class: prevent instantiation.
    }

    /**
     * Acquires the shared thread pool, creating it if it does not exist yet or has been shut down.
     * Every call takes a lease that must be released with exactly one
     * {@link #shutdown(ThreadPoolExecutor, String)} call.
     *
     * @param why label of the caller, used in log messages
     * @return the shared thread pool
     */
    public static ThreadPoolExecutor getInstance(String why) {
        logger.info("正在为[{}]申请线程池", why);
        // A plain lock (rather than double-checked locking) keeps pool creation and the lease
        // count consistent with shutdown().
        synchronized (CommonThreadPool.class) {
            if (executor == null || executor.isShutdown()) {
                executor = new ThreadPoolExecutor(
                        CORE_POOL_SIZE,
                        MAX_POOL_SIZE,
                        KEEP_ALIVE_TIME,
                        UNIT,
                        new LinkedBlockingQueue<>(QUEUE_CAPACITY),
                        THREAD_FACTORY,
                        // When saturated, run the task on the submitting thread (back-pressure)
                        // instead of silently discarding an already-queued task.
                        new ThreadPoolExecutor.CallerRunsPolicy()
                );
                logger.info("[{}]申请线程池成功", why);
            }
            leaseCount++;
            return executor;
        }
    }

    /**
     * Runs a batch of independent tasks that each produce a result.
     * Tasks that have not completed when the timeout elapses are cancelled.
     *
     * @param threadPoolExecutor pool to run the tasks on
     * @param tasks              tasks to run
     * @param timeout            overall timeout in seconds
     * @param <T>                result type of the tasks
     * @return one future per task, in the order of {@code tasks}
     * @throws InterruptedException if interrupted while waiting
     */
    public static <T> List<Future<T>> invokeAll(ThreadPoolExecutor threadPoolExecutor, List<? extends Callable<T>> tasks, long timeout) throws InterruptedException {
        return threadPoolExecutor.invokeAll(tasks, timeout, TimeUnit.SECONDS);
    }

    /**
     * Releases the caller's lease on a pool obtained from {@link #getInstance(String)}.
     * The shared pool is only shut down once no other caller holds a lease. Any other
     * executor passed in is shut down directly.
     *
     * @param threadPoolExecutor pool to release; {@code null} is ignored
     * @param why                label of the caller, used in log messages
     */
    public static void shutdown(ThreadPoolExecutor threadPoolExecutor, String why) {
        if (threadPoolExecutor == null) {
            return;
        }
        synchronized (CommonThreadPool.class) {
            if (threadPoolExecutor == executor) {
                if (leaseCount > 0) {
                    leaseCount--;
                }
                if (leaseCount > 0) {
                    // Other callers are still using the shared pool: only release this lease.
                    logger.info("[{}]已释放线程池使用权,仍有{}个使用者", why, leaseCount);
                    return;
                }
            }
            if (threadPoolExecutor.isShutdown()) {
                return;
            }
            threadPoolExecutor.shutdown();
        }
        // Wait outside the lock so that other callers can obtain a fresh pool in the meantime.
        try {
            if (threadPoolExecutor.awaitTermination(SHUTDOWN_TIMEOUT_SECONDS, TimeUnit.SECONDS)) {
                logger.info("[{}]线程池资源正常释放", why);
            } else {
                threadPoolExecutor.shutdownNow();
                logger.warn("[{}]线程池等待超时,已强制关闭", why);
            }
        } catch (InterruptedException e) {
            threadPoolExecutor.shutdownNow();
            logger.warn("[{}]等待线程池关闭时被中断,已强制关闭", why);
            Thread.currentThread().interrupt();
        }
    }
}
上下文获取工具类
import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextAware;
import org.springframework.stereotype.Component;
/**
 * Static access point to the Spring {@link ApplicationContext} for code that is not itself a
 * Spring bean, such as tasks running on thread-pool threads.
 *
 * @author zhang
 * @since 2024-08-28 17:58
 */
@Component
public class ContextLocator implements ApplicationContextAware {
    // volatile: written when Spring injects the context, read later from worker threads.
    private static volatile ApplicationContext applicationContext;

    /**
     * Returns the stored context.
     *
     * @return the application context, or {@code null} if Spring has not injected it yet
     */
    public static ApplicationContext getApplicationContext() {
        return applicationContext;
    }

    @Override
    public void setApplicationContext(ApplicationContext applicationContext) {
        ContextLocator.applicationContext = applicationContext;
    }

    /**
     * Looks up a bean by name.
     *
     * @param beanName name of the bean
     * @return the bean instance
     * @throws IllegalStateException if the context has not been injected yet
     */
    public static Object getBean(String beanName) {
        return requireContext().getBean(beanName);
    }

    /**
     * Looks up a bean by type.
     *
     * @param beanType type of the bean
     * @param <T>      bean type
     * @return the bean instance
     * @throws IllegalStateException if the context has not been injected yet
     */
    public static <T> T getBean(Class<T> beanType) {
        return requireContext().getBean(beanType);
    }

    /**
     * Returns the context, failing fast with a descriptive error instead of a NullPointerException.
     */
    private static ApplicationContext requireContext() {
        ApplicationContext context = applicationContext;
        if (context == null) {
            throw new IllegalStateException("ApplicationContext is not initialized.");
        }
        return context;
    }
}
批量执行公共类
import org.springframework.context.ApplicationContext;
import java.util.concurrent.Callable;
/**
 * Callable that invokes a public single-argument method on a Spring bean by reflection, so that
 * batch jobs do not need a hand-written Callable per business method.
 * Only one argument is supported; use with care when the business logic depends on the returned
 * value.
 *
 * @param <T> type of the method argument
 * @param <R> return type of the method
 * @author zhang
 * @since 2024-08-28 17:47
 */
public class CommonCallable<T, R> implements Callable<R> {
    private final T data;
    private final Class<?> beanClass;
    private final String methodName;
    private final Class<R> returnType;

    /**
     * Creates the task.
     *
     * @param data       argument passed to the method ({@code null} is allowed when the
     *                   parameter is a reference type)
     * @param beanClass  type of the Spring bean that declares the method
     * @param methodName name of the method to invoke
     * @param returnType expected return type of the method
     */
    public CommonCallable(T data, Class<?> beanClass, String methodName, Class<R> returnType) {
        this.data = data;
        this.beanClass = beanClass;
        this.methodName = methodName;
        this.returnType = returnType;
    }

    /**
     * Looks up the bean and invokes the method with {@code data}.
     *
     * @return the method's return value cast to {@code returnType}
     * @throws IllegalStateException if the ApplicationContext has not been initialized
     * @throws NoSuchMethodException if no public single-argument method accepts {@code data}
     * @throws Exception             any other reflective invocation failure
     */
    @Override
    public R call() throws Exception {
        ApplicationContext context = ContextLocator.getApplicationContext();
        if (context == null) {
            throw new IllegalStateException("ApplicationContext is not initialized");
        }
        Object bean = context.getBean(beanClass);
        java.lang.reflect.Method method = resolveMethod(bean.getClass());
        return returnType.cast(method.invoke(bean, data));
    }

    /**
     * Finds the public one-argument method named {@code methodName} on {@code type} that accepts
     * {@code data}. It first tries the exact runtime class of {@code data}. It then falls back to
     * any overload whose parameter is a supertype, an interface or the primitive counterpart of
     * {@code data}. The choice among several compatible overloads is unspecified.
     */
    private java.lang.reflect.Method resolveMethod(Class<?> type) throws NoSuchMethodException {
        if (data != null) {
            try {
                return type.getMethod(methodName, data.getClass());
            } catch (NoSuchMethodException ignored) {
                // Fall through: the parameter may be declared as a supertype or a primitive.
            }
        }
        for (java.lang.reflect.Method candidate : type.getMethods()) {
            if (candidate.getName().equals(methodName)
                    && candidate.getParameterCount() == 1
                    && accepts(candidate.getParameterTypes()[0])) {
                return candidate;
            }
        }
        String argType = data == null ? "null" : data.getClass().getName();
        throw new NoSuchMethodException(type.getName() + "." + methodName + "(" + argType + ")");
    }

    /**
     * Returns whether {@code data} can be passed to a parameter declared as {@code parameterType}.
     */
    private boolean accepts(Class<?> parameterType) {
        if (data == null) {
            return !parameterType.isPrimitive();
        }
        // wrap() maps primitives to their wrapper (int -> Integer) so an Integer fits an int.
        Class<?> boxed = java.lang.invoke.MethodType.methodType(parameterType).wrap().returnType();
        return boxed.isInstance(data);
    }
}
测试公共返回类
/**
 * Generic response wrapper used by the demo services and controllers.
 *
 * <p>Lombok {@code @Data} generates getters, setters, {@code equals}/{@code hashCode} and
 * {@code toString}. The explicit constructors and {@link #success(Object)} are required by the
 * demo call sites ({@code new CommonResult(boolean, String, <number>, String)} and
 * {@code CommonResult.success(list)}), which {@code @Data} alone does not provide.
 */
@Data
public class CommonResult {
    private boolean success;
    private String code;
    private String message;
    private Object object;

    /**
     * Creates an empty result with all fields at their default values.
     */
    public CommonResult() {
    }

    /**
     * Creates a fully populated result.
     *
     * <p>NOTE(review): argument order is inferred from the demo call sites. There the payload is
     * passed third and the business {@code code} value last. Confirm against the real class.
     *
     * @param success whether the call succeeded
     * @param message human-readable message
     * @param object  payload
     * @param code    business code
     */
    public CommonResult(boolean success, String message, Object object, String code) {
        this.success = success;
        this.message = message;
        this.object = object;
        this.code = code;
    }

    /**
     * Creates a successful result that carries {@code object} as its payload.
     * {@code code} and {@code message} are left {@code null}.
     *
     * @param object payload
     * @return a result with {@code success == true}
     */
    public static CommonResult success(Object object) {
        return new CommonResult(true, null, object, null);
    }
}
测试业务类
import org.springframework.stereotype.Service;
/**
 * Demo business service for the single-argument batch example.
 *
 * @author zhang
 * @since 2024-08-28 18:12
 */
@Service
public class StudentService {

    /**
     * Simulated business call whose outcome depends on the parity of the current time.
     *
     * @param count value placed in the result when the current millisecond timestamp is odd
     *              ({@code 0} is used when it is even)
     * @return a result whose success flag is {@code true} when the timestamp is even
     */
    public CommonResult tall(Integer count) {
        final long nowMillis = System.currentTimeMillis();
        final boolean evenTick = nowMillis % 2 == 0;
        return new CommonResult(evenTick, "", evenTick ? 0 : count, "");
    }
}
测试控制器
/**
 * Demo endpoint: runs 100 independent {@link CommonCallable} tasks on the shared thread pool and
 * returns the results of the tasks that completed successfully.
 *
 * @return CommonResult wrapping the collected task results
 */
@GetMapping("/testThreads")
public CommonResult testThread() {
    String name = UUID.randomUUID().toString();
    ThreadPoolExecutor test1 = CommonThreadPool.getInstance(name);
    List<CommonResult> a = new ArrayList<>();
    try {
        List<CommonCallable<Integer, CommonResult>> callableList = new ArrayList<>();
        for (int i = 0; i < 100; i++) {
            callableList.add(new CommonCallable<>(i, StudentService.class, "tall", CommonResult.class));
        }
        List<Future<CommonResult>> futures = CommonThreadPool.invokeAll(test1, callableList, 30L);
        for (Future<CommonResult> future : futures) {
            // Tasks are independent: one failed or timed-out (cancelled) task must not discard
            // the results of the remaining ones.
            try {
                a.add(future.get());
            } catch (java.util.concurrent.ExecutionException | java.util.concurrent.CancellationException e) {
                // TODO(review): log e here; no logger is visible in this snippet.
            }
        }
    } catch (InterruptedException e) {
        // Restore the interrupt status instead of swallowing it.
        Thread.currentThread().interrupt();
    } catch (Exception e) {
        // TODO(review): log e here; no logger is visible in this snippet.
    } finally {
        CommonThreadPool.shutdown(test1, name);
    }
    return CommonResult.success(a);
}
2024-08-29 新增:多个参数的情况
公共类(多参)
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.context.ApplicationContext;
import java.util.concurrent.Callable;
/**
 * Callable that invokes a public multi-argument method on a Spring bean by reflection, so that
 * batch jobs do not need a hand-written Callable per business method.
 * {@code returnType} must match the return type of the invoked method; use with care when the
 * business logic depends on the returned value.
 *
 * @param <R> return type of the invoked method
 * @author zhang
 * @since 2024-08-29 10:00
 */
public class CommonMoreCallable<R> implements Callable<R> {
    private final static Logger LOG = LoggerFactory.getLogger(CommonMoreCallable.class);
    private final Object[] data;
    private final Class<?>[] dataClass;
    private final Class<?> beanClass;
    private final String methodName;
    private final Class<R> returnType;

    /**
     * Creates the task. Arguments are validated when the task runs.
     *
     * @param data       arguments passed to the method, in declaration order
     * @param dataClass  declared parameter types used to look up the method; primitive types such
     *                   as {@code int.class} are supported
     * @param beanClass  type of the Spring bean that declares the method
     * @param methodName name of the method to invoke
     * @param returnType expected return type of the method
     */
    public CommonMoreCallable(Object[] data, Class<?>[] dataClass, Class<?> beanClass, String methodName, Class<R> returnType) {
        this.data = data;
        this.dataClass = dataClass;
        this.beanClass = beanClass;
        this.methodName = methodName;
        this.returnType = returnType;
    }

    /**
     * Validates the arguments, looks up the bean and invokes the method.
     *
     * @return the method's return value cast to {@code returnType} (may be {@code null})
     * @throws IllegalArgumentException if {@code data} and {@code dataClass} differ in length or
     *                                  an argument does not fit its declared type
     * @throws IllegalStateException    if the ApplicationContext has not been initialized
     * @throws ClassCastException       if the method returns a non-null value that is not a
     *                                  {@code returnType}
     * @throws Exception                any reflective lookup or invocation failure
     */
    @Override
    public R call() throws Exception {
        if (data.length != dataClass.length) {
            LOG.warn("参数数组长度和参数类型数组长度不一致");
            // Throwing fails only this task; invokeAll reports it through Future.get() as an
            // ExecutionException.
            throw new IllegalArgumentException("参数数组长度和参数类型数组长度不一致");
        }
        for (int i = 0; i < data.length; i++) {
            if (!fits(dataClass[i], data[i])) {
                LOG.warn("参数和参数类型不一致");
                throw new IllegalArgumentException("参数和参数类型不一致");
            }
        }
        ApplicationContext context = ContextLocator.getApplicationContext();
        if (context == null) {
            throw new IllegalStateException("ApplicationContext is not initialized");
        }
        Object bean = context.getBean(beanClass);
        Object invoke = bean.getClass().getMethod(methodName, dataClass).invoke(bean, data);
        // null is a valid result for any reference return type, so only non-null values are checked.
        if (invoke != null && !returnType.isInstance(invoke)) {
            LOG.warn("返回类型不匹配");
        }
        // Throws ClassCastException on a mismatch.
        return returnType.cast(invoke);
    }

    /**
     * Returns whether {@code value} can be passed to a parameter declared as {@code type}.
     * Primitive types are compared through their wrapper (an Integer fits {@code int.class})
     * and never accept {@code null}.
     */
    private static boolean fits(Class<?> type, Object value) {
        if (value == null) {
            return !type.isPrimitive();
        }
        Class<?> boxed = java.lang.invoke.MethodType.methodType(type).wrap().returnType();
        return boxed.isInstance(value);
    }
}
测试业务类
import org.springframework.stereotype.Service;
/**
 * Demo business service for the multi-argument batch example.
 *
 * @author zhang
 * @since 2024-08-29 10:00
 */
@Service
public class StudentMoreService {

    /**
     * Simulated business call: prints {@code count + "_" + name} and builds a result whose
     * success flag depends on the parity of the current millisecond timestamp.
     *
     * @param count first business argument
     * @param name  second business argument
     * @return the simulated result
     */
    public CommonResult tall(Integer count, String name) {
        final long nowMillis = System.currentTimeMillis();
        final String joined = count + "_" + name;
        final boolean evenTick = nowMillis % 2L == 0;
        System.out.println(joined);
        return new CommonResult(evenTick, "", evenTick ? 0 : nowMillis, joined);
    }
}
测试控制器
/**
 * Demo endpoint: runs 100 independent {@link CommonMoreCallable} tasks on the shared thread pool
 * and returns the results of the tasks that completed successfully.
 *
 * @return CommonResult wrapping the collected task results
 */
@GetMapping("/testThreads")
public CommonResult testThread() {
    String name = UUID.randomUUID().toString();
    ThreadPoolExecutor test1 = CommonThreadPool.getInstance(name);
    List<CommonResult> a = new ArrayList<>();
    try {
        List<CommonMoreCallable<CommonResult>> callableList = new ArrayList<>();
        for (int i = 0; i < 100; i++) {
            callableList.add(new CommonMoreCallable<>(new Object[]{i, name}, new Class<?>[]{Integer.class, String.class}, StudentMoreService.class, "tall", CommonResult.class));
        }
        List<Future<CommonResult>> futures = CommonThreadPool.invokeAll(test1, callableList, 30L);
        for (Future<CommonResult> future : futures) {
            // Tasks are independent: one failed or timed-out task must not discard the others.
            try {
                a.add(future.get());
            } catch (java.util.concurrent.ExecutionException e) {
                LOG.error("[{}] batch task failed", name, e);
            } catch (java.util.concurrent.CancellationException e) {
                LOG.warn("[{}] batch task was cancelled after the invokeAll timeout", name);
            }
        }
    } catch (InterruptedException e) {
        LOG.warn("[{}] interrupted while running batch tasks", name);
        // Restore the interrupt status instead of swallowing it.
        Thread.currentThread().interrupt();
    } catch (Exception e) {
        // Pass the exception itself so the stack trace is logged, not just its message.
        LOG.error("[{}] batch execution failed", name, e);
    } finally {
        CommonThreadPool.shutdown(test1, name);
    }
    return CommonResult.success(a);
}
写这个的想法来源于去年工作中遇到的一个问题:业务中有一个批量执行的操作,客服一次打开多个浏览器窗口同时进行批量操作,每次操作有 300 多条数据需要入库、查库并请求外部接口,导致应用的其他接口请求超时(说白了就是应用差点崩了)。经过排查发现,这里使用了线程池 ThreadPoolExecutor,再通过 invokeAll 去执行任务,但前人挖了两个坑:一是申请的线程池资源在执行结束后没有释放;二是在方法内用 for 循环去执行这个所谓的批量操作。后来经过改造,测试后达到了预期效果。今年工作中又遇到一个需要批量执行的操作,联想到去年的问题,就想能不能写一个公共类,来代替每次都要繁琐地实现 Callable 接口的做法,于是有了这篇文章。若有大佬发现文章内容有不对的地方,劳烦指出,感谢。
附上之前改写的模拟业务代码
/**
 * Callable that uploads a single student record through the Spring-managed StudentManager bean.
 */
public class StudentCallable implements Callable<CommonResult> {
    private final StudentInfo studentInfo;

    /**
     * @param studentInfo record to upload
     */
    public StudentCallable(StudentInfo studentInfo) {
        this.studentInfo = studentInfo;
    }

    /**
     * Resolves StudentManager from the application context and uploads the record.
     *
     * @return the result of {@code StudentManager.upload}
     * @throws Exception whatever the upload throws
     */
    @Override
    public CommonResult call() throws Exception {
        // Resolve by type through the ApplicationContext, which returns a typed bean regardless
        // of which getBean overloads ContextLocator exposes.
        StudentManager studentManager = ContextLocator.getApplicationContext().getBean(StudentManager.class);
        return studentManager.upload(this.studentInfo);
    }
}
/// (code above omitted)
// Build one Callable per student record.
// NOTE(review): StudentCallable takes a StudentInfo; confirm Student is a StudentInfo subtype.
List<StudentCallable> records = Lists.newArrayList();
for (Student studentInfo : studentInfoList) {
    records.add(new StudentCallable(studentInfo));
}
/* Acquire the shared thread pool. */
ThreadPoolExecutor studentUp = CommonThreadPool.getInstance("业务");
/* Run the batch; tasks still unfinished after 60 seconds are cancelled. */
List<Future<CommonResult>> results = studentUp.invokeAll(records, 60L, TimeUnit.SECONDS);
// Pair getInstance with CommonThreadPool.shutdown(studentUp, "业务"), ideally in a finally
// block, once the results have been read; otherwise the pool resources are never released.
/// (code below omitted)
最后优化
2024-08-31
import com.google.common.util.concurrent.ThreadFactoryBuilder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.lang.reflect.Method;
import java.util.List;
import java.util.concurrent.*;
/**
 * Shared, lazily created thread pool for independent batch tasks.
 *
 * <p>Each {@link #getInstance(String)} call takes a lease on the shared pool and must be paired
 * with exactly one {@link #shutdown(ThreadPoolExecutor, String)} call. The shared pool is only
 * shut down when the last outstanding lease is released, so one caller finishing can no longer
 * shut down the pool underneath another caller that obtained the same instance concurrently
 * (after {@code shutdown()} new submissions are rejected and discarded by the rejection policy,
 * and tasks still running when the shutdown timeout elapses are interrupted via
 * {@code shutdownNow()}).
 *
 * @author zhang
 * @since 2024-08-28 17:31
 */
public class CommonThreadPool {
    private static final Logger logger = LoggerFactory.getLogger(CommonThreadPool.class);
    private static final int CORE_POOL_SIZE = 100;
    /**
     * Maximum pool size. With the bounded queue below, threads beyond the core size are only
     * started once the queue is full.
     */
    private static final int MAX_POOL_SIZE = 500;
    private static final long KEEP_ALIVE_TIME = 60L;
    private static final TimeUnit UNIT = TimeUnit.SECONDS;
    private static final int QUEUE_CAPACITY = 5000;
    /**
     * Seconds {@link #shutdown} waits for in-flight tasks before forcing termination.
     */
    private static final long SHUTDOWN_TIMEOUT_SECONDS = 60L;
    private static volatile ThreadPoolExecutor executor;
    /**
     * Number of callers currently holding the shared executor; guarded by
     * {@code CommonThreadPool.class}.
     */
    private static int leaseCount;
    private static final ThreadFactory THREAD_FACTORY = new ThreadFactoryBuilder()
            .setNameFormat("ThreadPool-Thread-%d")
            .setDaemon(true)
            .build();

    private CommonThreadPool() {
        // Utility class: prevent instantiation.
    }

    /**
     * Acquires the shared thread pool, creating it if it does not exist yet or has been shut down.
     * Every call takes a lease that must be released with exactly one
     * {@link #shutdown(ThreadPoolExecutor, String)} call.
     *
     * @param reason label of the caller, used in log messages
     * @return the shared thread pool
     */
    public static ThreadPoolExecutor getInstance(String reason) {
        logger.info("Attempting to acquire thread pool for [{}]", reason);
        // A plain lock (rather than double-checked locking) keeps pool creation and the lease
        // count consistent with shutdown().
        synchronized (CommonThreadPool.class) {
            if (executor == null || executor.isShutdown()) {
                executor = new ThreadPoolExecutor(
                        CORE_POOL_SIZE,
                        MAX_POOL_SIZE,
                        KEEP_ALIVE_TIME,
                        UNIT,
                        new LinkedBlockingQueue<>(QUEUE_CAPACITY),
                        THREAD_FACTORY,
                        // When saturated, run the task on the submitting thread (back-pressure).
                        new ThreadPoolExecutor.CallerRunsPolicy()
                );
                logger.info("[{}] Acquired thread pool successfully", reason);
            }
            leaseCount++;
            return executor;
        }
    }

    /**
     * Runs a batch of independent tasks that each produce a result.
     * Tasks that have not completed when the timeout elapses are cancelled.
     *
     * @param threadPoolExecutor pool to run the tasks on
     * @param tasks              tasks to run
     * @param timeout            overall timeout in seconds
     * @param <T>                result type of the tasks
     * @return one future per task, in the order of {@code tasks}
     * @throws InterruptedException if interrupted while waiting
     */
    public static <T> List<Future<T>> invokeAll(ThreadPoolExecutor threadPoolExecutor, List<? extends Callable<T>> tasks, long timeout) throws InterruptedException {
        return threadPoolExecutor.invokeAll(tasks, timeout, TimeUnit.SECONDS);
    }

    /**
     * Releases the caller's lease on a pool obtained from {@link #getInstance(String)}.
     * The shared pool is only shut down once no other caller holds a lease. Any other
     * executor passed in is shut down directly.
     *
     * @param threadPoolExecutor pool to release; {@code null} is ignored
     * @param reason             label of the caller, used in log messages
     */
    public static void shutdown(ThreadPoolExecutor threadPoolExecutor, String reason) {
        if (threadPoolExecutor == null) {
            return;
        }
        synchronized (CommonThreadPool.class) {
            if (threadPoolExecutor == executor) {
                if (leaseCount > 0) {
                    leaseCount--;
                }
                if (leaseCount > 0) {
                    // Other callers are still using the shared pool: only release this lease.
                    logger.info("[{}] Released lease on the shared thread pool, {} lease(s) still active", reason, leaseCount);
                    return;
                }
            }
            if (threadPoolExecutor.isShutdown()) {
                return;
            }
            threadPoolExecutor.shutdown();
        }
        // Wait outside the lock so that other callers can obtain a fresh pool in the meantime.
        try {
            if (threadPoolExecutor.awaitTermination(SHUTDOWN_TIMEOUT_SECONDS, TimeUnit.SECONDS)) {
                logger.info("[{}] Thread pool resources released normally", reason);
            } else {
                threadPoolExecutor.shutdownNow();
                logger.warn("[{}] Forced shutdown of the thread pool due to timeout", reason);
            }
        } catch (InterruptedException e) {
            threadPoolExecutor.shutdownNow();
            logger.warn("[{}] Interrupted while waiting for the thread pool to terminate; forced shutdown", reason);
            Thread.currentThread().interrupt();
        }
    }

    /**
     * Looks up a public method.
     *
     * @param beanClass      class declaring the method
     * @param methodName     method name
     * @param parameterTypes declared parameter types
     * @return the method
     * @throws IllegalArgumentException if no such public method exists (cause: the
     *                                  {@link NoSuchMethodException})
     */
    public static Method getMethod(Class<?> beanClass, String methodName, Class<?>... parameterTypes) {
        try {
            return beanClass.getMethod(methodName, parameterTypes);
        } catch (NoSuchMethodException e) {
            throw new IllegalArgumentException("No public method " + methodName
                    + java.util.Arrays.toString(parameterTypes) + " on " + beanClass.getName(), e);
        }
    }
}
import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextAware;
import org.springframework.stereotype.Component;
import org.springframework.beans.BeansException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * Static access point to the Spring {@link ApplicationContext} for code that is not itself a
 * Spring bean, such as tasks running on thread-pool threads. Only the first injected context is
 * kept; later injections are ignored with a warning.
 *
 * @author zhang
 * @since 2024-08-28 17:58
 */
@Component
public class ContextLocator implements ApplicationContextAware {
    private static final Logger LOG = LoggerFactory.getLogger(ContextLocator.class);
    // volatile: required for the double-checked initialization below to be safe, and so that
    // worker threads see the context written when Spring injects it.
    private static volatile ApplicationContext applicationContext;

    /**
     * Returns the stored context.
     *
     * @return the application context, or {@code null} (with a warning) if not injected yet
     */
    public static ApplicationContext getApplicationContext() {
        ApplicationContext context = applicationContext;
        if (context == null) {
            LOG.warn("ApplicationContext has not been initialized yet.");
        }
        return context;
    }

    @Override
    public void setApplicationContext(ApplicationContext applicationContext) {
        if (ContextLocator.applicationContext == null) {
            synchronized (ContextLocator.class) {
                if (ContextLocator.applicationContext == null) {
                    ContextLocator.applicationContext = applicationContext;
                    LOG.info("ApplicationContext initialized successfully.");
                }
            }
        } else {
            LOG.warn("ApplicationContext was already initialized.");
        }
    }

    /**
     * Looks up a bean by name.
     *
     * @param beanName name of the bean
     * @return the bean instance
     * @throws IllegalStateException if the context has not been injected yet
     * @throws BeansException        if the bean cannot be obtained
     */
    public static Object getBean(String beanName) {
        ApplicationContext context = applicationContext;
        if (context == null) {
            LOG.error("ApplicationContext is not initialized. Cannot retrieve bean [{}].", beanName);
            throw new IllegalStateException("ApplicationContext is not initialized.");
        }
        try {
            return context.getBean(beanName);
        } catch (BeansException e) {
            LOG.error("Failed to retrieve bean [{}].", beanName, e);
            throw e;
        }
    }

    /**
     * Looks up a bean by type.
     *
     * @param beanType type of the bean
     * @param <T>      bean type
     * @return the bean instance
     * @throws IllegalStateException if the context has not been injected yet
     * @throws BeansException        if the bean cannot be obtained
     */
    public static <T> T getBean(Class<T> beanType) {
        ApplicationContext context = applicationContext;
        if (context == null) {
            LOG.error("ApplicationContext is not initialized. Cannot retrieve bean of type [{}].", beanType);
            throw new IllegalStateException("ApplicationContext is not initialized.");
        }
        try {
            return context.getBean(beanType);
        } catch (BeansException e) {
            LOG.error("Failed to retrieve bean of type [{}].", beanType, e);
            throw e;
        }
    }
}
import org.springframework.context.ApplicationContext;
import java.util.concurrent.Callable;
/**
 * Callable that invokes a public single-argument method on a Spring bean by reflection, so that
 * batch jobs do not need a hand-written Callable per business method.
 * Only one argument is supported; use with care when the business logic depends on the returned
 * value.
 *
 * @param <T> type of the method argument
 * @param <R> return type of the method
 * @author zhang
 * @since 2024-08-28 17:47
 */
public class CommonCallable<T, R> implements Callable<R> {
    private final T data;
    private final Class<?> beanClass;
    private final String methodName;
    private final Class<R> returnType;

    /**
     * Creates the task.
     *
     * @param data       argument passed to the method ({@code null} is allowed when the
     *                   parameter is a reference type)
     * @param beanClass  type of the Spring bean that declares the method
     * @param methodName name of the method to invoke
     * @param returnType expected return type of the method
     */
    public CommonCallable(T data, Class<?> beanClass, String methodName, Class<R> returnType) {
        this.data = data;
        this.beanClass = beanClass;
        this.methodName = methodName;
        this.returnType = returnType;
    }

    /**
     * Looks up the bean and invokes the method with {@code data}.
     *
     * @return the method's return value cast to {@code returnType}
     * @throws IllegalStateException if the ApplicationContext has not been initialized
     * @throws NoSuchMethodException if no public single-argument method accepts {@code data}
     * @throws Exception             any other reflective invocation failure
     */
    @Override
    public R call() throws Exception {
        ApplicationContext context = ContextLocator.getApplicationContext();
        if (context == null) {
            throw new IllegalStateException("ApplicationContext is not initialized");
        }
        Object bean = context.getBean(beanClass);
        java.lang.reflect.Method method = resolveMethod(bean.getClass());
        return returnType.cast(method.invoke(bean, data));
    }

    /**
     * Finds the public one-argument method named {@code methodName} on {@code type} that accepts
     * {@code data}. It first tries the exact runtime class of {@code data}. It then falls back to
     * any overload whose parameter is a supertype, an interface or the primitive counterpart of
     * {@code data}. The choice among several compatible overloads is unspecified.
     */
    private java.lang.reflect.Method resolveMethod(Class<?> type) throws NoSuchMethodException {
        if (data != null) {
            try {
                return type.getMethod(methodName, data.getClass());
            } catch (NoSuchMethodException ignored) {
                // Fall through: the parameter may be declared as a supertype or a primitive.
            }
        }
        for (java.lang.reflect.Method candidate : type.getMethods()) {
            if (candidate.getName().equals(methodName)
                    && candidate.getParameterCount() == 1
                    && accepts(candidate.getParameterTypes()[0])) {
                return candidate;
            }
        }
        String argType = data == null ? "null" : data.getClass().getName();
        throw new NoSuchMethodException(type.getName() + "." + methodName + "(" + argType + ")");
    }

    /**
     * Returns whether {@code data} can be passed to a parameter declared as {@code parameterType}.
     */
    private boolean accepts(Class<?> parameterType) {
        if (data == null) {
            return !parameterType.isPrimitive();
        }
        // wrap() maps primitives to their wrapper (int -> Integer) so an Integer fits an int.
        Class<?> boxed = java.lang.invoke.MethodType.methodType(parameterType).wrap().returnType();
        return boxed.isInstance(data);
    }
}
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.context.ApplicationContext;
import java.lang.reflect.Method;
import java.util.Arrays;
import java.util.concurrent.Callable;
/**
 * Callable that invokes a public multi-argument method on a Spring bean by reflection, so that
 * batch jobs do not need a hand-written Callable per business method.
 * {@code returnType} must match the return type of the invoked method; use with care when the
 * business logic depends on the returned value.
 *
 * @param <R> return type of the invoked method
 * @author zhang
 * @since 2024-08-29 10:00
 */
public class CommonMoreCallable<R> implements Callable<R> {
    private final static Logger LOG = LoggerFactory.getLogger(CommonMoreCallable.class);
    private final Object[] data;
    private final Class<?>[] dataClass;
    private final Class<?> beanClass;
    private final String methodName;
    private final Class<R> returnType;

    /**
     * Creates the task and validates its arguments eagerly.
     *
     * @param data       arguments passed to the method, in declaration order (copied)
     * @param dataClass  declared parameter types used to look up the method (copied)
     * @param beanClass  type of the Spring bean that declares the method
     * @param methodName name of the method to invoke
     * @param returnType expected return type of the method
     * @throws IllegalArgumentException if any argument is {@code null} or the array lengths differ
     */
    public CommonMoreCallable(Object[] data, Class<?>[] dataClass, Class<?> beanClass, String methodName, Class<R> returnType) {
        if (data == null || dataClass == null || beanClass == null || methodName == null || returnType == null) {
            throw new IllegalArgumentException("Arguments cannot be null");
        }
        if (data.length != dataClass.length) {
            throw new IllegalArgumentException("Data array length does not match data class array length");
        }
        // Defensive copies: the caller may reuse or mutate its arrays after submitting the task.
        this.data = data.clone();
        this.dataClass = dataClass.clone();
        this.beanClass = beanClass;
        this.methodName = methodName;
        this.returnType = returnType;
    }

    /**
     * Looks up the bean and the method, then invokes it.
     *
     * @return the method's return value cast to {@code returnType} (may be {@code null})
     * @throws IllegalStateException if the ApplicationContext is not initialized or the bean is missing
     * @throws NoSuchMethodException if no public method matches {@code methodName}/{@code dataClass}
     * @throws ClassCastException    if the method returns a non-null value that is not a {@code returnType}
     * @throws Exception             any reflective invocation failure
     */
    @Override
    public R call() throws Exception {
        ApplicationContext context = ContextLocator.getApplicationContext();
        if (context == null) {
            throw new IllegalStateException("ApplicationContext is not initialized");
        }
        Object bean = context.getBean(beanClass);
        if (bean == null) {
            LOG.error("Bean with class {} not found in application context", beanClass.getName());
            throw new IllegalStateException("Bean not found in application context");
        }
        Method method;
        try {
            method = beanClass.getMethod(methodName, dataClass);
        } catch (NoSuchMethodException e) {
            LOG.error("Method {} with parameters {} not found in {}", methodName, Arrays.toString(dataClass), beanClass.getName(), e);
            NoSuchMethodException detailed = new NoSuchMethodException("Method not found: "
                    + beanClass.getName() + "." + methodName + Arrays.toString(dataClass));
            detailed.initCause(e);
            throw detailed;
        }
        Object result;
        try {
            result = method.invoke(bean, data);
        } catch (Exception e) {
            LOG.error("Error invoking method {} on bean of type {}", methodName, beanClass.getName(), e);
            throw e;
        }
        // null is a valid result for any reference return type, so only non-null values are checked.
        if (result != null && !returnType.isInstance(result)) {
            LOG.error("Return value of method {} is not of expected type {}", methodName, returnType.getName());
            throw new ClassCastException("Return type mismatch: expected " + returnType.getName()
                    + " but got " + result.getClass().getName());
        }
        return returnType.cast(result);
    }
}
使用示例