思路:
1、自定义一个注解,用来标记"需要异步执行、且在异步线程中仍要读取 HttpServletRequest"的方法;
2、使用 AOP 拦截被该注解标记的方法:
a、进入方法前,在调用线程中复制 Request 中所需的信息;
b、把方法提交到线程池执行,并在工作线程中把 Request 复制对象保存到线程变量(ThreadLocal)中;
c、在工作线程中执行原方法。
具体代码:
自定义注解:可获取 HttpServletRequest 的异步执行
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
/**
 * Marks a method, or every method of a type, that should run asynchronously while still
 * being able to read the HttpServletRequest it was called from.
 *
 * <p>Retained at runtime so the AOP pointcut can match it on methods and on types.
 */
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.METHOD, ElementType.TYPE})
public @interface HasReqAsync {
}
线程池
mport java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
/**
* 线程池
*/
@EnableAsync
@Configuration
@Setter
@ConfigurationProperties(prefix = "task.pool")
public class ThreadPoolConfig {
/**
* 线程池中的核心线程数量,默认为1
*/
private int corePoolSize = 5;
/**
* 线程池中的最大线程数量
*/
private int maxPoolSize = 10;
/**
* 线程池中允许线程的空闲时间,默认为 60s
*/
private int keepAliveTime = ((int) TimeUnit.SECONDS.toSeconds(60));
/**
* 线程池中的队列最大数量
*/
private int queueCapacity = 100;
/**
* 线程的名称前缀
*/
private static final String THREAD_PREFIX = "thread-call-runner-";
@Bean
@Lazy
public ThreadPoolTaskExecutor threadPool() {
ThreadPoolTaskExecutor executor = new ThreadPoolTaskExecutor();
executor.setCorePoolSize(corePoolSize);
executor.setMaxPoolSize(maxPoolSize);
executor.setKeepAliveSeconds(keepAliveTime);
executor.setQueueCapacity(queueCapacity);
executor.setThreadNamePrefix(THREAD_PREFIX);
executor.setRejectedExecutionHandler(new ThreadPoolExecutor.CallerRunsPolicy());
executor.initialize();
return executor;
}
}
自定义线程变量
import lombok.extern.log4j.Log4j2;
import org.springframework.lang.Nullable;
import javax.servlet.http.HttpServletRequest;
/**
 * Per-thread slot for an HttpServletRequest that should be visible to code running on the
 * current thread (used by the async aspect to hand a request copy to worker threads).
 */
@Log4j2
public class MyRequestContextHolder {
    /** One stored request per thread. */
    private static final ThreadLocal<HttpServletRequest> REQUEST_SLOT = new ThreadLocal<>();

    /**
     * Returns the request stored for the current thread.
     *
     * @return the stored request, or {@code null} when nothing is stored
     */
    @Nullable
    public static HttpServletRequest getRequest() {
        return REQUEST_SLOT.get();
    }

    /**
     * Stores {@code request} for the current thread, replacing any previous value.
     *
     * @param request the request to store; may be {@code null}
     */
    public static void setRequestAttributes(@Nullable HttpServletRequest request) {
        REQUEST_SLOT.set(request);
    }

    /**
     * Removes whatever request is stored for the current thread.
     */
    public static void resetRequest() {
        REQUEST_SLOT.remove();
    }
}
自定义Request复制类
import cn.hutool.core.io.IoUtil;
import javax.servlet.ReadListener;
import javax.servlet.ServletInputStream;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import java.io.BufferedReader;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.Collections;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
/**
 * Request wrapper that snapshots the body, headers and parameters of a request so they can
 * still be read on another thread.
 *
 * <p>Only the body ({@link #getInputStream()}, {@link #getReader()}), headers
 * ({@link #getHeader}, {@link #getHeaders}, {@link #getHeaderNames()}) and parameters are
 * served from the snapshot. Every other accessor (URI, attributes, session, cookies, ...) is
 * still delegated to the wrapped container request.
 * NOTE(review): containers may recycle that object once the original request completes, so
 * delegated values are presumably unsafe to read from an async thread — confirm before use.
 */
public class MyHttpServletRequestWrapper extends HttpServletRequestWrapper {
    /** Raw request body; empty when it could not be read. */
    private final byte[] body;
    /** Header name to all of its values; lookups are case-insensitive like HTTP header names. */
    private final Map<String, List<String>> headers;
    /** Unmodifiable copy of the parameter map, in the wrapped request's iteration order. */
    private final Map<String, String[]> parameters;

    /**
     * Snapshots the given request.
     *
     * @param request the request to copy; must not be null
     */
    public MyHttpServletRequestWrapper(HttpServletRequest request) {
        super(request);
        // Parameters first: for form POSTs the container parses parameters out of the body,
        // which it can only do while the body has not been consumed via getInputStream().
        parameters = copyParameters(request);
        body = readBody(request);
        headers = copyHeaders(request);
    }

    /** Reads the whole body; best effort, falling back to an empty body on failure. */
    private static byte[] readBody(HttpServletRequest request) {
        try {
            return IoUtil.readBytes(request.getInputStream());
        } catch (IOException | RuntimeException ex) {
            // getInputStream() throws IllegalStateException if getReader() was already used.
            // NOTE(review): IoUtil.readBytes is assumed to report read errors unchecked.
            return new byte[0];
        }
    }

    /** Copies every header with all of its values into a case-insensitive map. */
    private static Map<String, List<String>> copyHeaders(HttpServletRequest request) {
        Map<String, List<String>> copy = new TreeMap<>(String.CASE_INSENSITIVE_ORDER);
        Enumeration<String> names = request.getHeaderNames();
        if (names == null) {
            // The servlet API allows null when the container denies access to headers.
            return copy;
        }
        while (names.hasMoreElements()) {
            String name = names.nextElement();
            Enumeration<String> values = request.getHeaders(name);
            copy.put(name, values == null ? Collections.<String>emptyList() : Collections.list(values));
        }
        return copy;
    }

    /** Copies the parameter map (cloning each value array) into an unmodifiable map. */
    private static Map<String, String[]> copyParameters(HttpServletRequest request) {
        Map<String, String[]> copy = new LinkedHashMap<>();
        for (Map.Entry<String, String[]> entry : request.getParameterMap().entrySet()) {
            String[] values = entry.getValue();
            copy.put(entry.getKey(), values == null ? null : values.clone());
        }
        return Collections.unmodifiableMap(copy);
    }

    /** Returns the captured values of a header, or an empty list if absent or name is null. */
    private List<String> headerValues(String name) {
        if (name == null) {
            return Collections.emptyList();
        }
        List<String> values = headers.get(name);
        return values == null ? Collections.<String>emptyList() : values;
    }

    /**
     * Returns a fresh stream over the captured body, so the body can be read any number
     * of times.
     */
    @Override
    public ServletInputStream getInputStream() {
        final ByteArrayInputStream source = new ByteArrayInputStream(body);
        return new ServletInputStream() {
            @Override
            public boolean isFinished() {
                return source.available() == 0;
            }

            @Override
            public boolean isReady() {
                // Data is fully in memory, so a read never blocks.
                return true;
            }

            @Override
            public void setReadListener(ReadListener readListener) {
                // Non-blocking read notifications are not supported; kept as a no-op.
            }

            @Override
            public int read() {
                return source.read();
            }

            @Override
            public int read(byte[] buffer, int offset, int length) {
                return source.read(buffer, offset, length);
            }
        };
    }

    /**
     * Returns a reader over the captured body using the request's character encoding, or
     * ISO-8859-1 (the servlet default) when none is set.
     *
     * @throws IOException if the character encoding is not supported
     */
    @Override
    public BufferedReader getReader() throws IOException {
        String encoding = getCharacterEncoding();
        return new BufferedReader(new InputStreamReader(getInputStream(), encoding != null ? encoding : "ISO-8859-1"));
    }

    /** Returns the first captured value of the header (case-insensitive), or null. */
    @Override
    public String getHeader(String name) {
        List<String> values = headerValues(name);
        return values.isEmpty() ? null : values.get(0);
    }

    /** Returns all captured values of the header (case-insensitive); empty if absent. */
    @Override
    public Enumeration<String> getHeaders(String name) {
        return Collections.enumeration(headerValues(name));
    }

    /** Returns the names of all captured headers. */
    @Override
    public Enumeration<String> getHeaderNames() {
        return Collections.enumeration(headers.keySet());
    }

    /** Returns the first captured value of the parameter, or null if absent or empty. */
    @Override
    public String getParameter(String name) {
        String[] values = parameters.get(name);
        return values == null || values.length == 0 ? null : values[0];
    }

    /** Returns the captured, unmodifiable parameter map. */
    @Override
    public Map<String, String[]> getParameterMap() {
        return this.parameters;
    }

    /** Returns the names of all captured parameters. */
    @Override
    public Enumeration<String> getParameterNames() {
        return Collections.enumeration(parameters.keySet());
    }

    /** Returns a copy of the captured values of the parameter, or null if absent. */
    @Override
    public String[] getParameterValues(String name) {
        String[] values = parameters.get(name);
        return values == null ? null : values.clone();
    }
}
自定义异步处理逻辑
import com.example.demo2.异步.utils.ServletUtils;
import lombok.extern.log4j.Log4j2;
import org.aspectj.lang.ProceedingJoinPoint;
import org.aspectj.lang.annotation.Around;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.annotation.Pointcut;
import org.aspectj.lang.reflect.MethodSignature;
import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
import org.springframework.stereotype.Component;
import javax.annotation.Resource;
import javax.servlet.http.HttpServletRequest;
import java.lang.reflect.Method;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
/**
* 自定义异步
*/
@Log4j2
@Component
@Aspect
public class HasReqAsyncAop {
@Resource
ThreadPoolTaskExecutor threadPoolTaskExecutor;
//定义切入点
@Pointcut("@annotation(HasReqAsync) || @within(HasReqAsync)")
public void aspect() {
}
@Around("aspect()")
public Object around(ProceedingJoinPoint joinPoint) {
MethodSignature methodSignature = (MethodSignature) joinPoint.getSignature();
Method method = methodSignature.getMethod();
log.debug("{}.{} 方法使用异步进行处理", method.getDeclaringClass().getName(), method.getName());
HttpServletRequest request = new MyHttpServletRequestWrapper(ServletUtils.getRequest());
Class<?> methodReturnType = method.getReturnType();
if (Future.class.isAssignableFrom(methodReturnType)) {
return CompletableFuture.supplyAsync(() -> {
try {
MyRequestContextHolder.setRequestAttributes(request);
return joinPoint.proceed();
} catch (Throwable e) {
throw new RuntimeException(e);
}
}, threadPoolTaskExecutor).thenApply(e -> {
if (e instanceof CompletableFuture) {
try {
return ((CompletableFuture<?>) e).get();
} catch (InterruptedException | ExecutionException ex) {
throw new RuntimeException(ex);
}
}
return e;
});
} else {
threadPoolTaskExecutor.execute(() -> {
try {
//上下文重新赋值,保证后续逻辑可以使用
MyRequestContextHolder.setRequestAttributes(request);
joinPoint.proceed();
} catch (Throwable e) {
throw new RuntimeException(e);
} finally {
MyRequestContextHolder.resetRequest();
}
});
return null;
}
}
}
ServletUtils工具类
import com.example.demo2.异步.config.MyRequestContextHolder;
import lombok.extern.slf4j.Slf4j;
import org.springframework.web.context.request.RequestAttributes;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.servlet.http.HttpSession;
import java.io.UnsupportedEncodingException;
import java.net.URLDecoder;
import java.net.URLEncoder;
import java.nio.charset.StandardCharsets;
/**
 * Static helpers for the current servlet request, response and session.
 *
 * <p>The request bound through {@link MyRequestContextHolder} (set by the async aspect on
 * worker threads) takes precedence over Spring's {@link RequestContextHolder}. Every accessor
 * returns {@code null} instead of throwing when no request is available.
 */
@Slf4j
public class ServletUtils {
    /**
     * Returns a request parameter.
     *
     * @param name parameter name
     * @return the first value, or {@code null} if absent or no request is available
     */
    public static String getParameter(String name) {
        HttpServletRequest request = getRequest();
        return request == null ? null : request.getParameter(name);
    }

    /**
     * Returns a request header.
     *
     * @param name header name
     * @return the header value, or {@code null} if absent or no request is available
     */
    public static String getHeader(String name) {
        HttpServletRequest request = getRequest();
        return request == null ? null : request.getHeader(name);
    }

    /**
     * Returns the current request: the thread-bound copy if present, otherwise Spring's.
     *
     * @return the request, or {@code null} if none is available on this thread
     */
    public static HttpServletRequest getRequest() {
        HttpServletRequest request = MyRequestContextHolder.getRequest();
        if (request != null) {
            return request;
        }
        ServletRequestAttributes attributes = getRequestAttributes();
        return attributes == null ? null : attributes.getRequest();
    }

    /**
     * Returns the current response from Spring's request context.
     *
     * @return the response, or {@code null} if none is available on this thread
     */
    public static HttpServletResponse getResponse() {
        ServletRequestAttributes attributes = getRequestAttributes();
        return attributes == null ? null : attributes.getResponse();
    }

    /**
     * Returns the session of the current request, creating it if necessary.
     *
     * @return the session, or {@code null} if no request is available
     */
    public static HttpSession getSession() {
        HttpServletRequest request = getRequest();
        return request == null ? null : request.getSession();
    }

    /**
     * Returns Spring's request attributes for the current thread.
     *
     * @return the attributes, or {@code null} if none are bound or they are not servlet-based
     */
    public static ServletRequestAttributes getRequestAttributes() {
        RequestAttributes attributes = RequestContextHolder.getRequestAttributes();
        return attributes instanceof ServletRequestAttributes ? (ServletRequestAttributes) attributes : null;
    }

    /**
     * URL-encodes content as UTF-8.
     *
     * @param str content
     * @return the encoded content, or an empty string if encoding fails
     */
    public static String urlEncode(String str) {
        try {
            return URLEncoder.encode(str, StandardCharsets.UTF_8.name());
        } catch (UnsupportedEncodingException e) {
            log.warn("内容编码失败:", e);
            return "";
        }
    }

    /**
     * URL-decodes UTF-8 content.
     *
     * @param str content
     * @return the decoded content, or an empty string if decoding fails
     */
    public static String urlDecode(String str) {
        try {
            return URLDecoder.decode(str, StandardCharsets.UTF_8.name());
        } catch (UnsupportedEncodingException e) {
            log.warn("内容解码失败:", e);
            return "";
        }
    }
}
代码测试:
import com.example.demo2.异步.config.HasReqAsync;
import com.example.demo2.异步.utils.ServletUtils;
import lombok.extern.log4j.Log4j2;
import org.springframework.stereotype.Service;
import javax.annotation.Resource;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
/**
 * Demo service. Annotated with {@link HasReqAsync} at type level, so its methods are matched
 * by the aspect's {@code @within(HasReqAsync)} pointcut and run asynchronously.
 */
@Log4j2
@Service
@HasReqAsync
public class TestService {
    @Resource
    private Test2Service test2Service1;
    @Resource
    private Test2Service test2Service2;

    /**
     * Logs the {@code name} parameter visible on the current thread, then starts two
     * {@link Test2Service#test()} calls and logs both results.
     */
    public void test() {
        String name = ServletUtils.getParameter("name");
        log.info("{} name==>{}", Thread.currentThread().getName(), name);
        // Start both calls before waiting on either so they can run concurrently.
        CompletableFuture<String> infoCaseVoFuture1 = test2Service1.test();
        CompletableFuture<String> infoCaseVoFuture2 = test2Service2.test();
        try {
            log.info("test2Service1 返回结果 ==>{}", infoCaseVoFuture1.get());
            log.info("test2Service2 返回结果 ==>{}", infoCaseVoFuture2.get());
        } catch (InterruptedException e) {
            // Restore the interrupt flag so whoever owns this thread can still observe it.
            Thread.currentThread().interrupt();
            throw new RuntimeException(e);
        } catch (ExecutionException e) {
            throw new RuntimeException(e);
        }
    }
}
import com.example.demo2.异步.config.HasReqAsync;
import com.example.demo2.异步.utils.ServletUtils;
import lombok.extern.log4j.Log4j2;
import org.springframework.stereotype.Service;
import java.util.Random;
import java.util.concurrent.CompletableFuture;
/**
 * Demo service whose {@link #test()} is individually marked with {@link HasReqAsync}.
 */
@Log4j2
@Service
public class Test2Service {
    /**
     * Sleeps a random 1-5000 ms and prints the {@code name} parameter visible on the current
     * thread.
     *
     * @return an already-completed future holding "&lt;thread name&gt; &lt;current millis&gt;"
     * @throws RuntimeException wrapping the InterruptedException if the sleep is interrupted
     */
    @HasReqAsync
    public CompletableFuture<String> test() {
        System.out.println("进入");
        Random ra = new Random();
        int i = ra.nextInt(5000) + 1;
        try {
            System.out.println(Thread.currentThread().getName() + " 暂停时间:" + i + "毫秒");
            Thread.sleep(i);
        } catch (InterruptedException e) {
            // Restore the interrupt flag before converting to an unchecked exception.
            Thread.currentThread().interrupt();
            throw new RuntimeException(e);
        }
        System.out.println(Thread.currentThread().getName() + " " + ServletUtils.getParameter("name"));
        return CompletableFuture.completedFuture(Thread.currentThread().getName() + " " + System.currentTimeMillis());
    }
}
import com.example.demo2.异步.service.TestService;
import com.example.demo2.异步.utils.ServletUtils;
import lombok.extern.log4j.Log4j2;
import org.springframework.stereotype.Controller;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.ResponseBody;
import javax.annotation.Resource;
/**
 * Demo endpoint: logs the {@code name} parameter on the request thread, then hands off to
 * {@link TestService#test()}.
 */
@Log4j2
@Controller
public class TestController {
    @Resource
    private TestService testService;

    // http://127.0.0.1:8888/test?name=lisi
    @RequestMapping("/test")
    @ResponseBody
    public String test(@RequestParam(name = "name", defaultValue = "unknown user") String name) {
        log.info("BasicController name==>{}", name);
        final String nameViaUtils = ServletUtils.getParameter("name");
        final String requestThread = Thread.currentThread().getName();
        log.info("{} name2==>{}", requestThread, nameViaUtils);

        testService.test();

        log.info("主线程执行完成");
        final String reply = "主线程执行完成 " + name;
        return reply;
    }
}
demo 地址:可下载后运行,查看完整逻辑。
该方案可以在异步线程中获取 Request 信息,并且支持获取异步方法的返回结果。