线程池封装
import net.trueland.seal.utils.ThreadUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.slf4j.MDC;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicInteger;
/**
 * Thread pool that hands the submitting thread's SLF4J MDC map to the task.
 *
 * <p>Every task passed to {@code execute} or {@code submit} is wrapped with
 * {@code ThreadUtils.wrap} together with a snapshot of the caller's MDC map
 * ({@code MDC.getCopyOfContextMap()}) taken at submission time.
 */
public class MDCThreadPoolExecutor extends ThreadPoolExecutor {
    private static final Logger logger = LoggerFactory.getLogger(MDCThreadPoolExecutor.class);

    /** Prefix of the names given to worker threads. */
    private static final String THREADNAME_PREFIX = "employe_thread";

    /** Capacity of the bounded work queue. */
    private static final int QUEUE_CAPACITY = 32;

    /**
     * Creates the pool with a bounded queue of {@value #QUEUE_CAPACITY} tasks, named
     * worker threads, and the caller-runs saturation policy.
     *
     * <p>Saturation policies shipped with the JDK, for reference:
     * <ol>
     *   <li>{@code AbortPolicy} - throws an exception to reject the task</li>
     *   <li>{@code CallerRunsPolicy} - the submitting thread runs the task itself</li>
     *   <li>{@code DiscardOldestPolicy} - drops the longest-waiting queued task, then
     *       retries the submission</li>
     *   <li>{@code DiscardPolicy} - silently drops the task</li>
     * </ol>
     *
     * @param corePoolSize    number of core threads
     * @param maximumPoolSize maximum number of threads
     * @param keepAliveTime   idle time before a surplus thread is terminated
     * @param unit            unit of {@code keepAliveTime}
     */
    public MDCThreadPoolExecutor(
            int corePoolSize,
            int maximumPoolSize,
            long keepAliveTime,
            TimeUnit unit) {
        super(corePoolSize, maximumPoolSize, keepAliveTime, unit,
                new LinkedBlockingDeque<>(QUEUE_CAPACITY),
                new EmployeeThreadFactory(),
                new ThreadPoolExecutor.CallerRunsPolicy());
    }

    // NOTE(review): the inherited submit(...) implementations are expected to route
    // through execute(...), so a submitted task is probably wrapped twice. That is
    // redundant but harmless; confirm against the JDK docs before removing either layer.

    /**
     * Runs {@code task} with the caller's MDC snapshot.
     *
     * @throws NullPointerException if {@code task} is null
     */
    @Override
    public void execute(Runnable task) {
        requireTask(task);
        super.execute(ThreadUtils.wrap(task, MDC.getCopyOfContextMap()));
    }

    /**
     * Submits {@code task} with the caller's MDC snapshot.
     *
     * @throws NullPointerException if {@code task} is null
     */
    @Override
    public Future<?> submit(Runnable task) {
        requireTask(task);
        return super.submit(ThreadUtils.wrap(task, MDC.getCopyOfContextMap()));
    }

    /**
     * Submits {@code task} with the caller's MDC snapshot; the future yields {@code result}.
     *
     * @throws NullPointerException if {@code task} is null
     */
    @Override
    public <T> Future<T> submit(Runnable task, T result) {
        requireTask(task);
        return super.submit(ThreadUtils.wrap(task, MDC.getCopyOfContextMap()), result);
    }

    /**
     * Submits {@code task} with the caller's MDC snapshot.
     *
     * @throws NullPointerException if {@code task} is null
     */
    @Override
    public <T> Future<T> submit(Callable<T> task) {
        requireTask(task);
        return super.submit(ThreadUtils.wrap(task, MDC.getCopyOfContextMap()));
    }

    /**
     * Rejects a null task at submission time. Without this check the wrapper created
     * by {@code ThreadUtils.wrap} is non-null, so the superclass would accept it and
     * the failure would only surface later, when the wrapper runs.
     */
    private static void requireTask(Object task) {
        if (task == null) {
            throw new NullPointerException("task");
        }
    }

    /**
     * Thread factory that names threads {@code THREADNAME_PREFIX + n}, with
     * {@code n} starting at 1.
     */
    private static class EmployeeThreadFactory implements ThreadFactory {
        private final AtomicInteger count = new AtomicInteger(0);

        public EmployeeThreadFactory() {
        }

        @Override
        public Thread newThread(Runnable r) {
            Thread thread = new Thread(r);
            thread.setName(THREADNAME_PREFIX + count.incrementAndGet());
            return thread;
        }
    }
}
线程工具类
import net.trueland.seal.config.MDCThreadPoolExecutor;
import net.trueland.seal.constant.Constant;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.slf4j.MDC;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.Callable;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
/**
 * Thread helpers: a lazily created shared thread pool and task wrappers that install
 * an SLF4J MDC context (including a trace id) around a task.
 */
public class ThreadUtils {
    private static final Logger logger = LoggerFactory.getLogger(ThreadUtils.class);

    /** Lazily initialised shared pool; volatile so the unsynchronised read in getInstance is safe. */
    private static volatile ThreadPoolExecutor threadPoolExecutor;

    /**
     * Returns the shared thread pool, creating it on first use.
     *
     * @return the shared pool instance
     */
    public static ThreadPoolExecutor getInstance() {
        if (threadPoolExecutor == null) {
            synchronized (ThreadUtils.class) {
                // Re-check under the lock: another thread may have created it meanwhile.
                if (threadPoolExecutor == null) {
                    threadPoolExecutor = createThreadPool();
                }
            }
        }
        return threadPoolExecutor;
    }

    /**
     * Builds the pool: 8 core threads, 16 maximum, 60 seconds keep-alive.
     *
     * @return a new {@code MDCThreadPoolExecutor}
     */
    private static ThreadPoolExecutor createThreadPool() {
        logger.info("初始化线程池");
        return new MDCThreadPoolExecutor(
                8,
                16,
                60,
                TimeUnit.SECONDS
        );
    }

    /**
     * Wraps a task without a return value so that it runs with {@code threadContext}
     * as its MDC and always has a trace id.
     *
     * <p>The MDC that was present on the running thread before the task started is
     * put back afterwards, whether the task returns normally or throws. Exceptions
     * thrown by the task propagate to whoever runs the wrapper.
     *
     * @param runnable      the task to run
     * @param threadContext MDC map to install while the task runs; {@code null} means
     *                      start from an empty MDC
     * @return a runnable that installs the context, runs the task, and restores the
     *         previous context
     */
    public static Runnable wrap(final Runnable runnable, final Map<String, String> threadContext) {
        return new Runnable() {
            @Override
            public void run() {
                // Snapshot first: if the wrapper runs on the submitting thread (the pool
                // built above uses a caller-runs rejection policy), a plain MDC.clear()
                // afterwards would wipe that thread's own trace id.
                final Map<String, String> previous = MDC.getCopyOfContextMap();
                applyContext(threadContext);
                setTraceIdIfAbsent();
                try {
                    runnable.run();
                } finally {
                    applyContext(previous);
                }
            }
        };
    }

    /**
     * Wraps a task with a return value so that it runs with {@code threadContext}
     * as its MDC and always has a trace id.
     *
     * <p>The MDC that was present on the running thread before the task started is
     * put back afterwards, whether the task returns normally or throws.
     *
     * @param callable      the task to run
     * @param threadContext MDC map to install while the task runs; {@code null} means
     *                      start from an empty MDC
     * @param <T>           result type of the task
     * @return a callable that installs the context, runs the task, and restores the
     *         previous context
     */
    public static <T> Callable<T> wrap(final Callable<T> callable, final Map<String, String> threadContext) {
        return new Callable<T>() {
            @Override
            public T call() throws Exception {
                final Map<String, String> previous = MDC.getCopyOfContextMap();
                applyContext(threadContext);
                setTraceIdIfAbsent();
                try {
                    return callable.call();
                } finally {
                    applyContext(previous);
                }
            }
        };
    }

    /**
     * Replaces the current thread's MDC with {@code context}; a null map clears it.
     */
    private static void applyContext(Map<String, String> context) {
        if (context == null) {
            MDC.clear();
        } else {
            MDC.setContextMap(context);
        }
    }

    /**
     * Puts a random trace id into the MDC when none is present.
     */
    public static void setTraceIdIfAbsent() {
        if (MDC.get(Constant.TRACE_ID) == null) {
            MDC.put(Constant.TRACE_ID, getRandomTraceId());
        }
    }

    /**
     * Generates a random trace id: a random UUID with the dashes removed, upper-cased.
     *
     * @return the generated trace id
     */
    public static String getRandomTraceId() {
        return UUID.randomUUID().toString().replace("-", "").toUpperCase();
    }
}
拦截器拦截所有请求并设置MDC
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.slf4j.MDC;
import org.springframework.web.servlet.HandlerInterceptor;
import org.springframework.web.servlet.ModelAndView;
import net.trueland.seal.constant.Constant;
import net.trueland.seal.enums.RequestMethodEnum;
import net.trueland.seal.utils.ThreadUtils;
/**
 * Logging interceptor: puts a fresh trace id into the MDC for each intercepted
 * request and removes it again when the request completes.
 */
public class LogInterceptor implements HandlerInterceptor {
    public static final Logger logger = LoggerFactory.getLogger(LogInterceptor.class);

    /**
     * Assigns a new trace id to the request and logs its URI.
     *
     * @return always {@code true}; this interceptor never blocks a request
     */
    @Override
    public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler) throws Exception {
        // Pre-flight requests pass straight through without a trace id.
        // equalsIgnoreCase avoids the locale-sensitive toUpperCase() conversion.
        if (request.getMethod().equalsIgnoreCase(RequestMethodEnum.OPTIONS.getMessage())) {
            return true;
        }
        // Set the trace id BEFORE logging so that this first line can carry it
        // (when the log pattern prints the MDC key).
        MDC.put(Constant.TRACE_ID, ThreadUtils.getRandomTraceId());
        logger.info("日志拦截器-拦截的URI: {}", request.getRequestURI());
        return true;
    }

    /** No post-processing is needed. */
    @Override
    public void postHandle(HttpServletRequest request, HttpServletResponse response, Object handler, ModelAndView modelAndView) throws Exception {
    }

    /**
     * Removes the trace id from the MDC once the request has completed.
     */
    @Override
    public void afterCompletion(HttpServletRequest request, HttpServletResponse response, Object handler, Exception ex) throws Exception {
        MDC.remove(Constant.TRACE_ID);
    }
}