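The Problem
A friend ran into a problem: he wanted to use the HttpServletRequest API inside a Service method, but did not want to pass the HttpServletRequest object in as a parameter. The method is called from a great many Controllers, so adding one parameter would mean touching a pile of code. In short, he was being lazy. Still, how should this be solved?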
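Analysis
Not passing HttpServletRequest as a parameter means the Service method has to obtain the HttpServletRequest object directly.
As we know, in the classic synchronous Servlet model the web application server assigns one thread to handle each request. So the HttpServletRequest object obtained inside the Service method must be shared within a thread and isolated between threads.
That is exactly the use case for ThreadLocal.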
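To make "shared within a thread, isolated between threads" concrete, here is a minimal sketch that uses only the JDK. The class name ThreadLocalIsolationDemo and the value "request-A" are made up for illustration, and a plain String stands in for the HttpServletRequest.

```java
import java.util.concurrent.atomic.AtomicReference;

class ThreadLocalIsolationDemo {
    // One ThreadLocal instance; every thread sees only the value it set itself.
    private static final ThreadLocal<String> CURRENT = new ThreadLocal<>();

    /** Returns { value seen by the setting thread, value seen by another thread }. */
    static String[] demo() {
        CURRENT.set("request-A"); // bind to the calling thread
        try {
            AtomicReference<String> seenByOther = new AtomicReference<>();
            Thread other = new Thread(() -> seenByOther.set(CURRENT.get()));
            other.start();
            other.join();
            return new String[] { CURRENT.get(), seenByOther.get() };
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException(e);
        } finally {
            CURRENT.remove(); // release the reference held by the calling thread
        }
    }

    public static void main(String[] args) {
        String[] seen = demo();
        System.out.println("same thread:  " + seen[0]);
        System.out.println("other thread: " + seen[1]);
    }
}
```

The thread that called set() reads its own value back, while the second thread never set anything and therefore reads null.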
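Approach
So we need to grab the HttpServletRequest before the request is processed and set() it into a static ThreadLocal field of some class. Wherever it is needed, access that ThreadLocal statically and call its get() method to obtain the thread-isolated HttpServletRequest. Finally, once the request has finished, call the ThreadLocal's remove() method to release the reference.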
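The set/get/remove lifecycle can be sketched without a servlet container. This is an illustration under stated assumptions: HolderLifecycleDemo, handle() and service() are hypothetical names, a String stands in for the request, and a single-thread executor stands in for the server's pooled worker threads. It also hints at why the final remove() matters: without it, a reused thread would still see the previous request.

```java
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

class HolderLifecycleDemo {
    private static final ThreadLocal<String> HOLDER = new ThreadLocal<>();

    /** Stand-in for a Service method: reads the value without taking a parameter. */
    static String service() {
        return HOLDER.get();
    }

    /** Handles one "request": bind, run the service, optionally clean up. */
    static String handle(String request, boolean cleanUp) {
        HOLDER.set(request);
        try {
            return service();
        } finally {
            if (cleanUp) {
                HOLDER.remove();
            }
        }
    }

    /** Runs one request, then reads the holder from the same reused worker thread. */
    static String leftoverAfterRequest(boolean cleanUp) {
        ExecutorService pool = Executors.newSingleThreadExecutor();
        try {
            Callable<String> request = () -> handle("request-1", cleanUp);
            Callable<String> laterTask = HolderLifecycleDemo::service; // binds nothing itself
            pool.submit(request).get();
            return pool.submit(laterTask).get();
        } catch (Exception e) {
            throw new IllegalStateException(e);
        } finally {
            pool.shutdown();
        }
    }

    public static void main(String[] args) {
        System.out.println("with remove():    " + leftoverAfterRequest(true));
        System.out.println("without remove(): " + leftoverAfterRequest(false));
    }
}
```

With cleanup, the later task finds nothing bound to the thread; without it, the stale value from the earlier request is still there.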
Implementation
Option 1: Using a ServletRequestListener
import javax.servlet.ServletRequestEvent;
import javax.servlet.ServletRequestListener;
import javax.servlet.http.HttpServletRequest;

public class RequestHolder implements ServletRequestListener {

    private static final ThreadLocal<HttpServletRequest> httpServletRequestHolder =
            new ThreadLocal<HttpServletRequest>();

    @Override
    public void requestInitialized(ServletRequestEvent requestEvent) {
        HttpServletRequest request = (HttpServletRequest) requestEvent.getServletRequest();
        httpServletRequestHolder.set(request); // bind to the current thread
    }

    @Override
    public void requestDestroyed(ServletRequestEvent requestEvent) {
        httpServletRequestHolder.remove(); // release the reference
    }

    public static HttpServletRequest getHttpServletRequest() {
        return httpServletRequestHolder.get();
    }
}
Option 2: Using a Filter
import java.io.IOException;
import javax.servlet.Filter;
import javax.servlet.FilterChain;
import javax.servlet.FilterConfig;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;

public class RequestHolder implements Filter {

    private static final ThreadLocal<HttpServletRequest> httpServletRequestHolder =
            new ThreadLocal<HttpServletRequest>();

    @Override
    public void init(FilterConfig filterConfig) throws ServletException {
    }

    @Override
    public void doFilter(ServletRequest request, ServletResponse response,
            FilterChain chain) throws IOException, ServletException {
        httpServletRequestHolder.set((HttpServletRequest) request); // bind to the current thread
        try {
            chain.doFilter(request, response);
        } finally {
            // Runs whether or not the chain throws, so no catch-and-rethrow is needed.
            httpServletRequestHolder.remove(); // release the reference
        }
    }

    @Override
    public void destroy() {
    }

    public static HttpServletRequest getHttpServletRequest() {
        return httpServletRequestHolder.get();
    }
}
Option 3: Using a Spring MVC interceptor
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.springframework.web.servlet.handler.HandlerInterceptorAdapter;

// Note: HandlerInterceptorAdapter is deprecated as of Spring 5.3;
// newer code can implement HandlerInterceptor directly.
public class RequestHolder extends HandlerInterceptorAdapter {

    private static final ThreadLocal<HttpServletRequest> httpServletRequestHolder =
            new ThreadLocal<HttpServletRequest>();

    @Override
    public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler)
            throws Exception {
        httpServletRequestHolder.set(request); // bind to the current thread
        return true;
    }

    @Override
    public void afterCompletion(HttpServletRequest request, HttpServletResponse response,
            Object handler, Exception ex)
            throws Exception {
        httpServletRequestHolder.remove(); // release the reference
    }

    public static HttpServletRequest getHttpServletRequest() {
        return httpServletRequestHolder.get();
    }
}
Usage
Whichever option you choose, a Service method can simply call
HttpServletRequest request = RequestHolder.getHttpServletRequest();
to obtain the thread-isolated HttpServletRequest.
Going Further
Spring MVC ships a similar feature out of the box. The code is
HttpServletRequest request =
        ((ServletRequestAttributes) RequestContextHolder.getRequestAttributes()).getRequest();
So how does Spring MVC implement it?
Start with the source of RequestContextHolder (abridged):
public abstract class RequestContextHolder {

    private static final ThreadLocal<RequestAttributes> requestAttributesHolder =
            new NamedThreadLocal<RequestAttributes>("Request attributes"); // key point

    private static final ThreadLocal<RequestAttributes> inheritableRequestAttributesHolder =
            new NamedInheritableThreadLocal<RequestAttributes>("Request context");

    public static void resetRequestAttributes() {
        requestAttributesHolder.remove(); // key point
        inheritableRequestAttributesHolder.remove();
    }

    public static void setRequestAttributes(RequestAttributes attributes) {
        setRequestAttributes(attributes, false);
    }

    public static void setRequestAttributes(RequestAttributes attributes, boolean inheritable) {
        if (attributes == null) {
            resetRequestAttributes();
        }
        else {
            if (inheritable) {
                inheritableRequestAttributesHolder.set(attributes);
                requestAttributesHolder.remove();
            }
            else {
                requestAttributesHolder.set(attributes); // key point
                inheritableRequestAttributesHolder.remove();
            }
        }
    }

    public static RequestAttributes getRequestAttributes() {
        RequestAttributes attributes = requestAttributesHolder.get(); // key point
        if (attributes == null) {
            attributes = inheritableRequestAttributesHolder.get();
        }
        return attributes;
    }
}
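The core of this code keeps a RequestAttributes object in a ThreadLocal and exposes static methods such as setRequestAttributes() and getRequestAttributes() to store and retrieve the thread-isolated RequestAttributes.
Next, when is setRequestAttributes() called?
Tracing the callers shows that setRequestAttributes() is called by initContextHolders(), which is in turn called by processRequest() in FrameworkServlet. processRequest() runs on every request, whether it is GET, POST, PUT, DELETE, TRACE, OPTIONS, or anything else.
First, look at the processRequest() method: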
protected final void processRequest(HttpServletRequest request, HttpServletResponse response)
        throws ServletException, IOException {

    long startTime = System.currentTimeMillis();
    Throwable failureCause = null;

    LocaleContext previousLocaleContext = LocaleContextHolder.getLocaleContext();
    LocaleContext localeContext = buildLocaleContext(request);

    RequestAttributes previousAttributes = RequestContextHolder.getRequestAttributes(); // key point 1
    ServletRequestAttributes requestAttributes =
            buildRequestAttributes(request, response, previousAttributes); // key point 2

    WebAsyncManager asyncManager = WebAsyncUtils.getAsyncManager(request);
    asyncManager.registerCallableInterceptor(FrameworkServlet.class.getName(), new RequestBindingInterceptor());

    initContextHolders(request, localeContext, requestAttributes); // key point 3

    try {
        doService(request, response); // handle the request
    }
    catch (ServletException ex) {
        failureCause = ex;
        throw ex;
    }
    catch (IOException ex) {
        failureCause = ex;
        throw ex;
    }
    catch (Throwable ex) {
        failureCause = ex;
        throw new NestedServletException("Request processing failed", ex);
    }
    finally {
        resetContextHolders(request, previousLocaleContext, previousAttributes); // key point 4
        if (requestAttributes != null) {
            requestAttributes.requestCompleted();
        }
        if (logger.isDebugEnabled()) {
            if (failureCause != null) {
                this.logger.debug("Could not complete request", failureCause);
            }
            else {
                if (asyncManager.isConcurrentHandlingStarted()) {
                    logger.debug("Leaving response open for concurrent processing");
                }
                else {
                    this.logger.debug("Successfully completed request");
                }
            }
        }
        publishRequestHandledEvent(request, startTime, failureCause); // publish the request-handled event
    }
}
Key point 1
The method calls get before it calls set; the result is normally null.
Key point 2
Go straight to the implementation of buildRequestAttributes():
protected ServletRequestAttributes buildRequestAttributes(HttpServletRequest request, HttpServletResponse response,
        RequestAttributes previousAttributes) {
    if (previousAttributes == null || previousAttributes instanceof ServletRequestAttributes) {
        return new ServletRequestAttributes(request); // key point
    }
    else {
        return null; // preserve the pre-bound RequestAttributes instance
    }
}
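We will not read through the ServletRequestAttributes code here. It is an implementation of the RequestAttributes interface and is simply a wrapper around the HttpServletRequest object (and the HttpSession).
Key point 3
Go straight to the implementation of initContextHolders():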
private void initContextHolders(HttpServletRequest request, LocaleContext localeContext,
        RequestAttributes requestAttributes) {
    if (localeContext != null) {
        LocaleContextHolder.setLocaleContext(localeContext, this.threadContextInheritable);
    }
    if (requestAttributes != null) {
        RequestContextHolder.setRequestAttributes(requestAttributes, this.threadContextInheritable); // key point
    }
    if (logger.isTraceEnabled()) {
        logger.trace("Bound request context to thread: " + request);
    }
}
This calls RequestContextHolder.setRequestAttributes() to store the requestAttributes object. this.threadContextInheritable defaults to false.
In other words, the ServletRequestAttributes wrapper around the HttpServletRequest is bound to the current thread.
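The inheritable flag decides which of the two holders receives the attributes. Spring's NamedInheritableThreadLocal builds on the JDK's InheritableThreadLocal, so the difference can be sketched with JDK classes alone. The following is only an analogy, not Spring code: InheritableDemo and its field names are hypothetical, and Strings stand in for RequestAttributes.

```java
class InheritableDemo {
    private static final ThreadLocal<String> PLAIN = new ThreadLocal<>();
    private static final ThreadLocal<String> INHERITABLE = new InheritableThreadLocal<>();

    /** Returns { PLAIN, INHERITABLE } as seen by a thread created by the caller. */
    static String[] seenByChild() {
        PLAIN.set("plain");
        INHERITABLE.set("inheritable");
        String[] seen = new String[2];
        // The child copies inheritable values from its parent at construction time.
        Thread child = new Thread(() -> {
            seen[0] = PLAIN.get();
            seen[1] = INHERITABLE.get();
        });
        try {
            child.start();
            child.join(); // join() makes the child's writes visible to the caller
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException(e);
        } finally {
            PLAIN.remove();
            INHERITABLE.remove();
        }
        return seen;
    }

    public static void main(String[] args) {
        String[] seen = seenByChild();
        System.out.println("plain in child:       " + seen[0]);
        System.out.println("inheritable in child: " + seen[1]);
    }
}
```

With the default of false, a thread spawned while handling a request behaves like the PLAIN case and does not see the bound value.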
Key point 4
private void resetContextHolders(HttpServletRequest request, LocaleContext prevLocaleContext,
        RequestAttributes previousAttributes) {
    LocaleContextHolder.setLocaleContext(prevLocaleContext, this.threadContextInheritable);
    RequestContextHolder.setRequestAttributes(previousAttributes, this.threadContextInheritable); // key point
    if (logger.isTraceEnabled()) {
        logger.trace("Cleared thread-bound request context: " + request);
    }
}
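After the request has been processed, RequestContextHolder.setRequestAttributes() is called again. Because previousAttributes is normally null, as noted earlier, this is equivalent to calling RequestContextHolder.setRequestAttributes(null, false).
Let's look at setRequestAttributes() once more.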
public static void setRequestAttributes(RequestAttributes attributes, boolean inheritable) {
    if (attributes == null) {
        resetRequestAttributes();
    }
    else {
        if (inheritable) {
            inheritableRequestAttributesHolder.set(attributes);
            requestAttributesHolder.remove();
        }
        else {
            requestAttributesHolder.set(attributes);
            inheritableRequestAttributesHolder.remove();
        }
    }
}
When the attributes argument is null, resetRequestAttributes() is called, which clears the RequestAttributes referenced by the current thread.
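The four key points amount to a "save previous, bind, run, restore previous" pattern. Here is a minimal JDK-only sketch of that pattern, assuming a String in place of RequestAttributes; ScopedHolderDemo, runWith() and nestedTrace() are hypothetical names, not part of Spring.

```java
import java.util.function.Supplier;

class ScopedHolderDemo {
    private static final ThreadLocal<String> HOLDER = new ThreadLocal<>();

    /** Like setRequestAttributes(): a null value means "clear the thread's binding". */
    static void set(String value) {
        if (value == null) {
            HOLDER.remove();
        } else {
            HOLDER.set(value);
        }
    }

    /** Binds value for the duration of body, then restores whatever was bound before. */
    static String runWith(String value, Supplier<String> body) {
        String previous = HOLDER.get(); // read before writing (key point 1)
        set(value);                     // bind to the current thread (key point 3)
        try {
            return body.get();
        } finally {
            set(previous);              // restore; clears when previous is null (key point 4)
        }
    }

    /** Nests two scopes and records what is bound inside, between, and after them. */
    static String nestedTrace() {
        String insideScopes = runWith("outer", () -> {
            String inner = runWith("inner", HOLDER::get);
            return inner + "," + HOLDER.get();
        });
        return insideScopes + "," + HOLDER.get();
    }

    public static void main(String[] args) {
        System.out.println(nestedTrace());
    }
}
```

Restoring the previous value instead of always clearing keeps nested scopes correct: the inner scope hands the thread back to the outer binding, and only the outermost scope ends with nothing bound.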
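That completes the source walkthrough of how Spring MVC lets you obtain the HttpServletRequest object directly. The idea is much the same as in our own implementation, just with a few more detours.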