spring异步@Async方法request丢失的问题处理

文章讲述了在Spring框架中,由于主线程与异步线程对request的处理问题,提出了一种创建自定义HttpServletRequest对象并使用RequestContextHolder替换线程request的方法,以避免资源冲突和性能问题。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

spring的request实际上是重复利用的,在调用异步进程的时候,主线程先结束就会回收request,而异步线程中将会失去request或者获取其他请求的request。这明显会出现很多问题。

此时有两种解决方式,一是在主线程中调用request.startAsync,然后在子线程中调用asyncContext.complete方法,但是这会导致request在异步任务完成之前一直被占用,有性能的问题。使用异步本来就是为了处理耗时的操作,所以不考虑。

二是创建一个自定义的request传入异步线程中,这样就不会跟其他请求相互影响。

创建自定义HttpServletRequest对象

import org.springframework.util.Assert;

import javax.servlet.*;
import javax.servlet.http.*;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.UnsupportedEncodingException;
import java.security.Principal;
import java.util.*;

/**
 * Detached snapshot of an {@link HttpServletRequest}.
 *
 * <p>The constructor copies attributes, parameters, cookies, headers and the
 * scalar request properties (paths, addresses, ports, method, ...) out of the
 * source request so that they can still be read after the source request has
 * been cleared. The request body, session, locale, security and async
 * information are NOT captured; the corresponding getters return
 * {@code null}, {@code false} or {@code 0} as noted on each method.
 *
 * <p>Instances are not synchronized. Attribute values are copied by reference,
 * not cloned.
 */
public class MyHttpServletRequest implements HttpServletRequest {
    private final Map<String, Object> attributes = new HashMap<String, Object>();
    private String servletPath;
    private StringBuffer requestURL;
    private String contextPath;
    private String requestURI;
    // Stays null when the source request carried no cookies, mirroring getCookies().
    private Cookie[] cookies;
    private String characterEncoding;
    private int contentLength = 0;
    private long contentLengthLong = 0;
    private String contentType;
    Map<String, String[]> parameterMap = new HashMap<>();
    // HTTP header names are case-insensitive, so lookups must ignore case.
    Map<String, List<String>> headMap = new TreeMap<>(String.CASE_INSENSITIVE_ORDER);
    private String remoteAddr;
    private String remoteUser;
    private String remoteHost;
    private int remotePort;
    private String localAddr;
    private String localName;
    private int localPort;
    private String serverName;
    private int serverPort;
    private String method;
    private String protocol;
    private String scheme;
    private String pathInfo;
    private String pathTranslated;
    private String queryString;

    /**
     * Copies the readable state of {@code request} into this object.
     *
     * @param request the live request to snapshot; must not be {@code null}
     */
    public MyHttpServletRequest(HttpServletRequest request) {
        // Copy everything now: once the original request is invalidated its data is cleared.
        Enumeration<String> attributeNames = request.getAttributeNames();
        while (attributeNames.hasMoreElements()) {
            String name = attributeNames.nextElement();
            this.attributes.put(name, request.getAttribute(name));
        }

        Enumeration<String> parameterNames = request.getParameterNames();
        while (parameterNames.hasMoreElements()) {
            String name = parameterNames.nextElement();
            String[] values = request.getParameterValues(name);
            this.parameterMap.put(name, values != null ? values.clone() : null);
        }

        // getCookies() returns null (not an empty array) when the client sent no cookies.
        Cookie[] sourceCookies = request.getCookies();
        if (sourceCookies != null) {
            this.cookies = new Cookie[sourceCookies.length];
            for (int i = 0; i < sourceCookies.length; i++) {
                this.cookies[i] = copyCookie(sourceCookies[i]);
            }
        }

        this.servletPath = request.getServletPath();
        // Take a private copy so later changes to the source buffer cannot leak in.
        StringBuffer sourceUrl = request.getRequestURL();
        this.requestURL = (sourceUrl != null ? new StringBuffer(sourceUrl) : null);
        this.requestURI = request.getRequestURI();
        this.contextPath = request.getContextPath();
        this.remoteAddr = request.getRemoteAddr();
        this.remoteUser = request.getRemoteUser();
        this.remoteHost = request.getRemoteHost();
        this.remotePort = request.getRemotePort();
        this.localAddr = request.getLocalAddr();
        this.localName = request.getLocalName();
        this.localPort = request.getLocalPort();
        this.characterEncoding = request.getCharacterEncoding();
        this.contentLength = request.getContentLength();
        this.contentLengthLong = request.getContentLengthLong();
        this.contentType = request.getContentType();

        // Both header enumerations may be null if the container restricts header access.
        Enumeration<String> headerNames = request.getHeaderNames();
        while (headerNames != null && headerNames.hasMoreElements()) {
            String name = headerNames.nextElement();
            Enumeration<String> values = request.getHeaders(name);
            this.headMap.put(name,
                    values != null ? Collections.list(values) : new ArrayList<String>());
        }

        this.serverName = request.getServerName();
        this.serverPort = request.getServerPort();
        this.method = request.getMethod();
        this.protocol = request.getProtocol();
        this.scheme = request.getScheme();
        this.pathInfo = request.getPathInfo();
        this.pathTranslated = request.getPathTranslated();
        this.queryString = request.getQueryString();
    }

    /**
     * Creates an independent copy of a cookie (name, value, max-age, path,
     * domain, secure and http-only flags).
     *
     * @param oldCookie the cookie to copy; must not be {@code null}
     * @return a new cookie carrying the same settings
     */
    public static Cookie copyCookie(Cookie oldCookie) {
        Cookie newCookie = new Cookie(oldCookie.getName(), oldCookie.getValue());
        newCookie.setMaxAge(oldCookie.getMaxAge());
        newCookie.setPath(oldCookie.getPath());
        // Only forward a domain that is actually set.
        if (oldCookie.getDomain() != null) {
            newCookie.setDomain(oldCookie.getDomain());
        }
        newCookie.setSecure(oldCookie.getSecure());
        newCookie.setHttpOnly(oldCookie.isHttpOnly());
        return newCookie;
    }

    /** Returns the captured values for a header, or {@code null} if absent or {@code name} is null. */
    private List<String> headerValues(String name) {
        // The case-insensitive comparator cannot handle a null key.
        return (name != null ? this.headMap.get(name) : null);
    }

    @Override
    public Object getAttribute(String name) {
        return attributes.get(name);
    }

    @Override
    public Enumeration<String> getAttributeNames() {
        return Collections.enumeration(this.attributes.keySet());
    }

    @Override
    public String getCharacterEncoding() {
        return this.characterEncoding;
    }

    @Override
    public void setCharacterEncoding(String env) throws UnsupportedEncodingException {
        this.characterEncoding = env;
    }

    @Override
    public int getContentLength() {
        return this.contentLength;
    }

    @Override
    public long getContentLengthLong() {
        return this.contentLengthLong;
    }

    @Override
    public String getContentType() {
        return this.contentType;
    }

    /** Not captured: the request body is not copied, so this always returns {@code null}. */
    @Override
    public ServletInputStream getInputStream() throws IOException {
        return null;
    }

    /**
     * Appends values to a parameter, creating the parameter if it does not exist yet.
     *
     * @param name   parameter name; must not be {@code null}
     * @param values values to append after any existing ones
     */
    public void addParameter(String name, String... values) {
        Assert.notNull(name, "Parameter name must not be null");
        String[] oldArr = this.parameterMap.get(name);
        if (oldArr != null) {
            String[] newArr = new String[oldArr.length + values.length];
            System.arraycopy(oldArr, 0, newArr, 0, oldArr.length);
            System.arraycopy(values, 0, newArr, oldArr.length, values.length);
            this.parameterMap.put(name, newArr);
        }
        else {
            this.parameterMap.put(name, values);
        }
    }

    /** Returns the first value of the parameter, or {@code null} if it has none. */
    @Override
    public String getParameter(String name) {
        String[] arr = (name != null ? this.parameterMap.get(name) : null);
        return (arr != null && arr.length > 0 ? arr[0] : null);
    }

    @Override
    public Enumeration<String> getParameterNames() {
        return Collections.enumeration(this.parameterMap.keySet());
    }

    @Override
    public String[] getParameterValues(String name) {
        return parameterMap.get(name);
    }

    /** Returns the live internal parameter map (not a copy). */
    @Override
    public Map<String, String[]> getParameterMap() {
        return parameterMap;
    }

    @Override
    public String getProtocol() {
        return this.protocol;
    }

    @Override
    public String getScheme() {
        return this.scheme;
    }

    @Override
    public String getServerName() {
        return this.serverName;
    }

    @Override
    public int getServerPort() {
        return this.serverPort;
    }

    /** Not captured: the request body is not copied, so this always returns {@code null}. */
    @Override
    public BufferedReader getReader() throws IOException {
        return null;
    }

    @Override
    public String getRemoteAddr() {
        return this.remoteAddr;
    }

    @Override
    public String getRemoteHost() {
        return this.remoteHost;
    }

    @Override
    public void setAttribute(String name, Object o) {
        attributes.put(name, o);
    }

    @Override
    public void removeAttribute(String name) {
        attributes.remove(name);
    }

    /** Not captured; always returns {@code null}. */
    @Override
    public Locale getLocale() {
        return null;
    }

    /** Not captured; always returns {@code null}. */
    @Override
    public Enumeration<Locale> getLocales() {
        return null;
    }

    /** Not captured; always returns {@code false}. */
    @Override
    public boolean isSecure() {
        return false;
    }

    /** Not supported on a detached request; always returns {@code null}. */
    @Override
    public RequestDispatcher getRequestDispatcher(String path) {
        return null;
    }

    /** Not supported on a detached request; always returns {@code null}. */
    @Override
    public String getRealPath(String path) {
        return null;
    }

    @Override
    public int getRemotePort() {
        return this.remotePort;
    }

    @Override
    public String getLocalName() {
        return this.localName;
    }

    @Override
    public String getLocalAddr() {
        return this.localAddr;
    }

    @Override
    public int getLocalPort() {
        return this.localPort;
    }

    /** Not captured; always returns {@code null}. */
    @Override
    public ServletContext getServletContext() {
        return null;
    }

    /** Async processing is not available on a detached request; always returns {@code null}. */
    @Override
    public AsyncContext startAsync() throws IllegalStateException {
        return null;
    }

    /** Async processing is not available on a detached request; always returns {@code null}. */
    @Override
    public AsyncContext startAsync(ServletRequest servletRequest, ServletResponse servletResponse)
            throws IllegalStateException {
        return null;
    }

    @Override
    public boolean isAsyncStarted() {
        return false;
    }

    @Override
    public boolean isAsyncSupported() {
        return false;
    }

    /** Always returns {@code null}; see {@link #startAsync()}. */
    @Override
    public AsyncContext getAsyncContext() {
        return null;
    }

    /** Not captured; always returns {@code null}. */
    @Override
    public DispatcherType getDispatcherType() {
        return null;
    }

    /** Not captured; always returns {@code null}. */
    @Override
    public String getAuthType() {
        return null;
    }

    /** Returns the copied cookies, or {@code null} if the source request had none. */
    @Override
    public Cookie[] getCookies() {
        return cookies;
    }

    /** Not implemented; always returns {@code 0} regardless of the header's value. */
    @Override
    public long getDateHeader(String name) {
        return 0;
    }

    /** Returns the first captured value of the header (name matched case-insensitively), or {@code null}. */
    @Override
    public String getHeader(String name) {
        List<String> list = headerValues(name);
        return list != null && !list.isEmpty() ? list.get(0) : null;
    }

    /** Returns all captured values of the header; an empty enumeration if the header is absent. */
    @Override
    public Enumeration<String> getHeaders(String name) {
        List<String> list = headerValues(name);
        if (list == null) {
            return Collections.<String>emptyEnumeration();
        }
        return Collections.enumeration(list);
    }

    @Override
    public Enumeration<String> getHeaderNames() {
        return Collections.enumeration(this.headMap.keySet());
    }

    /** Not implemented; always returns {@code 0} regardless of the header's value. */
    @Override
    public int getIntHeader(String name) {
        return 0;
    }

    @Override
    public String getMethod() {
        return this.method;
    }

    @Override
    public String getPathInfo() {
        return this.pathInfo;
    }

    @Override
    public String getPathTranslated() {
        return this.pathTranslated;
    }

    @Override
    public String getContextPath() {
        return this.contextPath;
    }

    @Override
    public String getQueryString() {
        return this.queryString;
    }

    @Override
    public String getRemoteUser() {
        return this.remoteUser;
    }

    /** Not captured; always returns {@code false}. */
    @Override
    public boolean isUserInRole(String role) {
        return false;
    }

    /** Not captured; always returns {@code null}. */
    @Override
    public Principal getUserPrincipal() {
        return null;
    }

    /** Not captured; always returns {@code null}. */
    @Override
    public String getRequestedSessionId() {
        return null;
    }

    @Override
    public String getRequestURI() {
        return this.requestURI;
    }

    /**
     * Returns a fresh copy of the captured request URL on every call, so a
     * caller appending to the result does not affect later callers.
     */
    @Override
    public StringBuffer getRequestURL() {
        return (this.requestURL != null ? new StringBuffer(this.requestURL) : null);
    }

    @Override
    public String getServletPath() {
        return servletPath;
    }

    /** Sessions are not available on a detached request; always returns {@code null}. */
    @Override
    public HttpSession getSession(boolean create) {
        return null;
    }

    /** Sessions are not available on a detached request; always returns {@code null}. */
    @Override
    public HttpSession getSession() {
        return null;
    }

    /** Sessions are not available on a detached request; always returns {@code null}. */
    @Override
    public String changeSessionId() {
        return null;
    }

    @Override
    public boolean isRequestedSessionIdValid() {
        return false;
    }

    @Override
    public boolean isRequestedSessionIdFromCookie() {
        return false;
    }

    @Override
    public boolean isRequestedSessionIdFromURL() {
        return false;
    }

    @Override
    public boolean isRequestedSessionIdFromUrl() {
        return false;
    }

    /** Not supported on a detached request; always returns {@code false}. */
    @Override
    public boolean authenticate(HttpServletResponse response) throws IOException, ServletException {
        return false;
    }

    /** Not supported on a detached request; does nothing. */
    @Override
    public void login(String username, String password) throws ServletException {
        // intentionally empty
    }

    /** Not supported on a detached request; does nothing. */
    @Override
    public void logout() throws ServletException {
        // intentionally empty
    }

    /** Multipart data is not captured; always returns {@code null}. */
    @Override
    public Collection<Part> getParts() throws IOException, ServletException {
        return null;
    }

    /** Multipart data is not captured; always returns {@code null}. */
    @Override
    public Part getPart(String name) throws IOException, ServletException {
        return null;
    }

    /** Protocol upgrade is not supported on a detached request; always returns {@code null}. */
    @Override
    public <T extends HttpUpgradeHandler> T upgrade(Class<T> httpUpgradeHandlerClass)
            throws IOException, ServletException {
        return null;
    }

}

然后在异步线程池的配置中,把自定义的request传入异步线程

import java.util.concurrent.Callable;
import java.util.concurrent.Executor;
import java.util.concurrent.Future;
import java.util.concurrent.ThreadPoolExecutor;

import com.kq.highnet2.framework.base.common.model.MyHttpServletRequest;
import org.springframework.aop.interceptor.AsyncUncaughtExceptionHandler;
import org.springframework.aop.interceptor.SimpleAsyncUncaughtExceptionHandler;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.scheduling.annotation.AsyncConfigurer;
import org.springframework.scheduling.annotation.EnableAsync;
import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
import org.springframework.util.concurrent.ListenableFuture;
import org.springframework.web.context.request.RequestAttributes;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;


@Configuration
@EnableAsync
public class ThreadConfig  implements AsyncConfigurer {

	public class ContextAwareCallable<T> implements Callable<T> {
		private Callable<T> task;
		private RequestAttributes context;

		public ContextAwareCallable(Callable<T> task, ServletRequestAttributes context) {
			this.task = task;
            //在初始化时就新建request对象,在call方法中request就取不到了
			this.context = new ServletRequestAttributes(new MyHttpServletRequest(context.getRequest()), context.getResponse());;
		}

		@Override
		public T call() throws Exception {
			if (context != null) {//通过这个方法就能替换线程中的request
				RequestContextHolder.setRequestAttributes(context);
			}
			try {
				return task.call();
			} finally {
                //释放request,finally优先级最高,即使有return也会执行
				RequestContextHolder.resetRequestAttributes();
			}
		}
	}

	public class ContextAwarePoolExecutor extends ThreadPoolTaskExecutor {
		@Override
		public <T> Future<T> submit(Callable<T> task) {
			return super.submit(new ContextAwareCallable(task, (ServletRequestAttributes)RequestContextHolder.currentRequestAttributes()));
		}

		@Override
		public <T> ListenableFuture<T> submitListenable(Callable<T> task) {
			return super.submitListenable(new ContextAwareCallable(task, (ServletRequestAttributes)RequestContextHolder.currentRequestAttributes()));
		}
	}
	@Override
	@Bean("Async")
	public Executor getAsyncExecutor() {
		ThreadPoolTaskExecutor executor = new ContextAwarePoolExecutor();
		// 核心线程池数量,方法;
		executor.setCorePoolSize(10);
		// 最大线程数量
		executor.setMaxPoolSize(20);
		// 线程池的队列容量
		executor.setQueueCapacity(1000);
		// 线程名称的前缀
		executor.setThreadNamePrefix("sync-executor-");
		// 线程池对拒绝任务的处理策略
		executor.setRejectedExecutionHandler(new ThreadPoolExecutor.CallerRunsPolicy());
		executor.initialize();
		return executor;
	}

	@Override
	public AsyncUncaughtExceptionHandler getAsyncUncaughtExceptionHandler() {
		return new SimpleAsyncUncaughtExceptionHandler();
	}
}

说起来,自定义request,然后通过RequestContextHolder.setRequestAttributes来传入线程的方法,在junit中也能使用,就不用担心代码中使用到了request了。

如果要用同一个线程池通过service方法调用,可以注入Async执行对象


/**
 * Service facade over the shared {@code "Async"} executor so that application
 * code can submit work to the same pool used by {@code @Async} methods.
 */
@Service
public class ThreadSyncExecutorServiceImpl implements IThreadSyncExecutorService {
    @Resource(name="Async")
    Executor asyncExecutor;

    /** Returns the injected executor as the concrete pool type this service relies on. */
    private ThreadConfig.ContextAwarePoolExecutor pool() {
        return (ThreadConfig.ContextAwarePoolExecutor) asyncExecutor;
    }

    /**
     * Submits every task and blocks until all of them have completed.
     *
     * <p>{@code ThreadPoolTaskExecutor} declares no {@code invokeAll}, so each
     * task is sent through {@code submit(Callable)} and then awaited.
     *
     * @param tasks the tasks to run
     * @return one future per task, in iteration order; each is done on return
     * @throws InterruptedException if interrupted while waiting, in which case
     *         the tasks submitted so far are cancelled
     */
    @Override
    public <T> List<Future<T>> invokeAll(Collection<? extends Callable<T>> tasks)
            throws InterruptedException{
        ThreadConfig.ContextAwarePoolExecutor executor = pool();
        List<Future<T>> futures = new java.util.ArrayList<Future<T>>(tasks.size());
        boolean finished = false;
        try {
            for (Callable<T> task : tasks) {
                futures.add(executor.submit(task));
            }
            for (Future<T> future : futures) {
                try {
                    future.get();
                } catch (java.util.concurrent.ExecutionException ignored) {
                    // The failure stays recorded in the future for the caller to inspect.
                } catch (java.util.concurrent.CancellationException ignored) {
                    // A cancelled future already counts as done.
                }
            }
            finished = true;
            return futures;
        } finally {
            if (!finished) {
                // Interrupted or failed part-way: do not leave orphaned tasks running.
                for (Future<T> future : futures) {
                    future.cancel(true);
                }
            }
        }
    }

    /** Submits a single task to the shared pool. */
    @Override
    public <T> Future<T> submit(Callable<T> task) {
        return pool().submit(task);
    }

    /** Runs a {@link Runnable} on the shared pool. */
    @Override
    public void execute(Runnable runnable){
        pool().execute(runnable);
    }
}

好的,下面是一个简单的示例: 假设我们有一个 UserService,其中有一个方法 sendEmail,它需要异步地发送电子邮件。现在我们想要将当前用户的信息透传给异步任务中使用的线程。 首先,我们需要在异步方法上添加 @Async 注解,并在配置类中启用异步支持: ```java @Configuration @EnableAsync public class AppConfig implements AsyncConfigurer { @Override public Executor getAsyncExecutor() { ThreadPoolTaskExecutor executor = new ThreadPoolTaskExecutor(); executor.setCorePoolSize(10); executor.setMaxPoolSize(100); executor.setQueueCapacity(10); executor.initialize(); return executor; } @Override public AsyncUncaughtExceptionHandler getAsyncUncaughtExceptionHandler() { return new CustomAsyncExceptionHandler(); } } ``` 在上面的示例中,我们创建了一个 ThreadPoolTaskExecutor,它将用于执行异步任务。我们还实现了 AsyncConfigurer 接口,并覆盖了 getAsyncExecutor 和 getAsyncUncaughtExceptionHandler 方法,以提供自定义的 Executor 和异常处理程序。 现在我们需要将当前用户信息存储在一个 ThreadLocal 对象中。这可以通过一个拦截器来实现: ```java public class UserContextInterceptor extends HandlerInterceptorAdapter { private final ThreadLocal<String> userThreadLocal = new ThreadLocal<>(); @Override public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler) throws Exception { String currentUser = request.getHeader("X-User"); userThreadLocal.set(currentUser); return true; } @Override public void afterCompletion(HttpServletRequest request, HttpServletResponse response, Object handler, Exception ex) throws Exception { userThreadLocal.remove(); } public String getCurrentUser() { return userThreadLocal.get(); } } ``` 在上面的示例中,我们创建了一个 UserContextInterceptor,它将在每个请求的开始和结束时执行。在 preHandle 方法中,我们从请求头中获取当前用户信息,并将其存储在一个 ThreadLocal 对象中。在 afterCompletion 方法中,我们将删除该信息,以避免内存泄漏。 现在,我们可以在 UserService 的 sendEmail 方法中使用 UserContextInterceptor 中存储的当前用户信息: ```java @Service public class UserService { @Autowired private JavaMailSender mailSender; @Autowired private UserContextInterceptor userContextInterceptor; @Async public void sendEmail(String to, String subject, String text) { String currentUser = userContextInterceptor.getCurrentUser(); // 使用当前用户信息发送电子邮件 // ... 
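} } ``` 在上面的示例中,我们使用 @Autowired 注解将 UserContextInterceptor 注入到 UserService 中,并在 sendEmail 方法中通过它读取当前用户信息。需要注意:ThreadLocal 中的值只对设置它的线程可见,而 @Async 方法运行在线程池的工作线程上,因此按上面的写法,getCurrentUser() 在异步方法中拿不到请求线程设置的值(通常为 null)。要真正把当前用户信息透传给异步任务所在的线程,应当在调用方线程中先取出该值并作为参数传给异步方法,或者像前文的 ContextAwareCallable 那样,在提交任务时捕获上下文、在工作线程中恢复并在结束后清理。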
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值