1. Scenario
Requirement: make the PrintService class print the IP of the current request.
@RestController
public class HelloController {
    private HelloService helloService;
    public HelloController(HelloService helloService) {
        this.helloService = helloService;
    }
    @GetMapping("hello")
    public String hello() {
        helloService.service();
        return "hello";
    }
}

@Service
public class HelloService {
    private PrintService printService;
    public HelloService(PrintService printService) {
        this.printService = printService;
    }
    public void service() {
        printService.print();
    }
}

@Service
public class PrintService {
    public void print() {
        System.out.println("hello, " + UUID.randomUUID());
    }
}
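2. Solutions
2.1 Option 1: Inject HttpServletRequest at the Controller layer
My previous approach was to add an HttpServletRequest parameter to the Controller method, which gives direct access to the HttpServletRequest object of the current request, and then pass the request down layer by layer. The drawback is obvious: every business method along the way gains an HttpServletRequest parameter it does not otherwise need, and the code is not very elegant.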
@RestController
public class HelloController {
    private HelloService helloService;
    public HelloController(HelloService helloService) {
        this.helloService = helloService;
    }
    @GetMapping("hello")
    public String hello(HttpServletRequest request) {
        helloService.service(request);
        return "hello";
    }
}

@Service
public class HelloService {
    private PrintService printService;
    public HelloService(PrintService printService) {
        this.printService = printService;
    }
    public void service(HttpServletRequest request) {
        printService.print(request);
    }
}

@Service
public class PrintService {
    public void print(HttpServletRequest request) {
        // getRemoteHost() may return a host name if the container resolves it;
        // getRemoteAddr() always returns the IP address.
        System.out.println("IP: " + request.getRemoteHost());
        System.out.println("hello, " + UUID.randomUUID());
    }
}
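2.2 Option 2: Inject HttpServletRequest directly into the Service (recommended)
Now I inject the HttpServletRequest object directly into the Service, which gives access to the current request. It produces the same result, but it also raises two questions:
- Is the HttpServletRequest obtained in these two ways the same object?
- Doesn't injecting it into a member field cause thread-safety problems?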
@RestController
public class HelloController {
    private final HelloService helloService;
    public HelloController(HelloService helloService) {
        this.helloService = helloService;
    }
    @GetMapping("hello")
    public String hello() {
        helloService.service();
        return "hello";
    }
}

@Service
public class HelloService {
    private final PrintService printService;
    public HelloService(PrintService printService) {
        this.printService = printService;
    }
    public void service() {
        printService.print();
    }
}

@Service
public class PrintService {
    private final HttpServletRequest request;
    public PrintService(HttpServletRequest request) {
        this.request = request;
    }
    public void print() {
        System.out.println("IP: " + request.getRemoteHost());
        System.out.println("hello, " + UUID.randomUUID());
    }
}
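3. Comparison
3.1 Request injected directly into the Service
Internally it is backed by an ObjectFactory.
3.2 Request injected as a Controller method parameter
Internally it is the native Request.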
4. How it works
4.1 Request injected directly into the Service
Inspecting the type of the injected request shows that it is a proxy object, backed internally by org.springframework.web.context.support.WebApplicationContextUtils.RequestObjectFactory
// Implements ObjectFactory, i.e. it is an object factory
private static class RequestObjectFactory implements ObjectFactory<ServletRequest>, Serializable {
    @Override
    public ServletRequest getObject() {
        return currentRequestAttributes().getRequest();
    }
    @Override
    public String toString() {
        return "Current HttpServletRequest";
    }
}
- Where is RequestObjectFactory registered with the Spring container?
Using Find Usages in IntelliJ IDEA, it is easy to find that WebApplicationContextUtils#registerWebApplicationScopes registers ObjectFactory instances for ServletRequest, ServletResponse, and so on.
// org.springframework.web.context.support.WebApplicationContextUtils#registerWebApplicationScopes
public static void registerWebApplicationScopes(ConfigurableListableBeanFactory beanFactory,
        @Nullable ServletContext sc) {
    beanFactory.registerScope(WebApplicationContext.SCOPE_REQUEST, new RequestScope());
    beanFactory.registerScope(WebApplicationContext.SCOPE_SESSION, new SessionScope());
    if (sc != null) {
        ServletContextScope appScope = new ServletContextScope(sc);
        beanFactory.registerScope(WebApplicationContext.SCOPE_APPLICATION, appScope);
        // Register as ServletContext attribute, for ContextCleanupListener to detect it.
        sc.setAttribute(ServletContextScope.class.getName(), appScope);
    }
    // RequestObjectFactory is registered here
    beanFactory.registerResolvableDependency(ServletRequest.class, new RequestObjectFactory());
    beanFactory.registerResolvableDependency(ServletResponse.class, new ResponseObjectFactory());
    beanFactory.registerResolvableDependency(HttpSession.class, new SessionObjectFactory());
    beanFactory.registerResolvableDependency(WebRequest.class, new WebRequestObjectFactory());
    if (jsfPresent) {
        FacesDependencyRegistrar.registerFacesDependencies(beanFactory);
    }
}

// org.springframework.beans.factory.support.DefaultListableBeanFactory#registerResolvableDependency
@Override
public void registerResolvableDependency(Class<?> dependencyType, @Nullable Object autowiredValue) {
    Assert.notNull(dependencyType, "Dependency type must not be null");
    if (autowiredValue != null) {
        // resolvableDependencies is a Map field of DefaultListableBeanFactory, the BeanFactory
        // implementation that Spring application contexts use by default
        this.resolvableDependencies.put(dependencyType, autowiredValue);
    }
}
- How is it injected into the Service?
At registration time the ObjectFactory was put into the resolvableDependencies map of DefaultListableBeanFactory, so it is enough to look at where resolvableDependencies is used.
// org.springframework.beans.factory.support.DefaultListableBeanFactory#findAutowireCandidates
protected Map<String, Object> findAutowireCandidates(
        @Nullable String beanName, Class<?> requiredType, DependencyDescriptor descriptor) {
    String[] candidateNames = BeanFactoryUtils.beanNamesForTypeIncludingAncestors(
            this, requiredType, true, descriptor.isEager());
    Map<String, Object> result = CollectionUtils.newLinkedHashMap(candidateNames.length);
    // Iterate over resolvableDependencies; the ServletRequest entry matches because
    // ServletRequest is assignable from the required HttpServletRequest
    for (Map.Entry<Class<?>, Object> classObjectEntry : this.resolvableDependencies.entrySet()) {
        Class<?> autowiringType = classObjectEntry.getKey();
        if (autowiringType.isAssignableFrom(requiredType)) {
            Object autowiringValue = classObjectEntry.getValue();
            // Turn the registered ObjectFactory into a value of the required type
            autowiringValue = AutowireUtils.resolveAutowiringValue(autowiringValue, requiredType);
            if (requiredType.isInstance(autowiringValue)) {
                result.put(ObjectUtils.identityToString(autowiringValue), autowiringValue);
                break;
            }
        }
    }
    // omitted...
    return result;
}

// org.springframework.beans.factory.support.AutowireUtils#resolveAutowiringValue
public static Object resolveAutowiringValue(Object autowiringValue, Class<?> requiredType) {
    if (autowiringValue instanceof ObjectFactory<?> factory && !requiredType.isInstance(autowiringValue)) {
        // RequestObjectFactory is Serializable and HttpServletRequest is an interface,
        // so this branch runs: a JDK proxy is injected, not the request itself
        if (autowiringValue instanceof Serializable && requiredType.isInterface()) {
            autowiringValue = Proxy.newProxyInstance(requiredType.getClassLoader(),
                    new Class<?>[] {requiredType}, new ObjectFactoryDelegatingInvocationHandler(factory));
        }
        else {
            // Otherwise call ObjectFactory.getObject() once and inject the result
            return factory.getObject();
        }
    }
    return autowiringValue;
}
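The lookup and proxy step above can be sketched without Spring. The following is a minimal, self-contained approximation, not Spring's implementation: Request, HttpRequest, Factory and ResolvableDependencyDemo are hypothetical stand-ins for ServletRequest, HttpServletRequest, ObjectFactory and DefaultListableBeanFactory. It shows a dependency registered under a parent interface being resolved for a sub-interface, and the injected value being a JDK proxy that asks the factory for the target on every call.

```java
import java.io.Serializable;
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Proxy;
import java.util.LinkedHashMap;
import java.util.Map;

final class ResolvableDependencyDemo {

    // Hypothetical stand-ins for ServletRequest and HttpServletRequest.
    interface Request { String remoteAddr(); }
    interface HttpRequest extends Request { }

    // Stand-in for Spring's ObjectFactory; Serializable, like RequestObjectFactory.
    interface Factory<T> extends Serializable { T getObject(); }

    // Plays the role of DefaultListableBeanFactory#resolvableDependencies.
    private final Map<Class<?>, Object> resolvableDependencies = new LinkedHashMap<>();

    void register(Class<?> dependencyType, Object autowiredValue) {
        resolvableDependencies.put(dependencyType, autowiredValue);
    }

    Object resolve(Class<?> requiredType) {
        for (Map.Entry<Class<?>, Object> entry : resolvableDependencies.entrySet()) {
            // Registered under Request, requested as HttpRequest: still a match.
            if (!entry.getKey().isAssignableFrom(requiredType)) {
                continue;
            }
            Object value = entry.getValue();
            if (value instanceof Factory && !requiredType.isInstance(value) && requiredType.isInterface()) {
                final Factory<?> factory = (Factory<?>) value;
                // The factory is consulted on every call, not once at injection time.
                InvocationHandler handler = (proxy, method, args) -> {
                    try {
                        return method.invoke(factory.getObject(), args);
                    } catch (InvocationTargetException ex) {
                        throw ex.getTargetException();
                    }
                };
                value = Proxy.newProxyInstance(
                        requiredType.getClassLoader(), new Class<?>[] {requiredType}, handler);
            }
            if (requiredType.isInstance(value)) {
                return value;
            }
        }
        return null;
    }

    static String demo() {
        ResolvableDependencyDemo beanFactory = new ResolvableDependencyDemo();
        HttpRequest current = () -> "10.0.0.1";
        Factory<Request> factory = () -> current;
        beanFactory.register(Request.class, factory);
        HttpRequest injected = (HttpRequest) beanFactory.resolve(HttpRequest.class);
        return "proxy=" + Proxy.isProxyClass(injected.getClass()) + " addr=" + injected.remoteAddr();
    }

    public static void main(String[] args) {
        System.out.println(demo()); // prints "proxy=true addr=10.0.0.1"
    }
}
```

Because the proxy holds only the factory, the field in a singleton Service never stores a concrete request; which request it reaches depends entirely on what the factory returns at call time.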
- How is thread safety guaranteed?
Look at the getObject method of RequestObjectFactory:
private static class RequestObjectFactory implements ObjectFactory<ServletRequest>, Serializable {
    @Override
    public ServletRequest getObject() {
        return currentRequestAttributes().getRequest();
    }
}

// org.springframework.web.context.support.WebApplicationContextUtils#currentRequestAttributes
private static ServletRequestAttributes currentRequestAttributes() {
    RequestAttributes requestAttr = RequestContextHolder.currentRequestAttributes();
    if (!(requestAttr instanceof ServletRequestAttributes servletRequestAttributes)) {
        throw new IllegalStateException("Current request is not a servlet request");
    }
    return servletRequestAttributes;
}

// org.springframework.web.context.request.RequestContextHolder#currentRequestAttributes
public static RequestAttributes currentRequestAttributes() throws IllegalStateException {
    RequestAttributes attributes = getRequestAttributes();
    // Omitted: throws IllegalStateException when no request is bound to the current thread
    return attributes;
}

// org.springframework.web.context.request.RequestContextHolder#getRequestAttributes
@Nullable
public static RequestAttributes getRequestAttributes() {
    // requestAttributesHolder is a ThreadLocal, so each thread only sees its own request
    RequestAttributes attributes = requestAttributesHolder.get();
    if (attributes == null) {
        attributes = inheritableRequestAttributesHolder.get();
    }
    return attributes;
}
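The fallback from requestAttributesHolder to inheritableRequestAttributesHolder matters once work moves to another thread. Here is a small sketch of the difference using only the JDK; HolderFallbackDemo and its field names are hypothetical, and the holders store a plain String instead of RequestAttributes. A value in a plain ThreadLocal is invisible to a child thread, while a value in an InheritableThreadLocal is copied to threads created afterwards.

```java
final class HolderFallbackDemo {

    // Stand-ins for requestAttributesHolder and inheritableRequestAttributesHolder.
    private static final ThreadLocal<String> HOLDER = new ThreadLocal<>();
    private static final ThreadLocal<String> INHERITABLE_HOLDER = new InheritableThreadLocal<>();

    // Same lookup order as getRequestAttributes(): plain holder first, then the inheritable one.
    static String get() {
        String value = HOLDER.get();
        if (value == null) {
            value = INHERITABLE_HOLDER.get();
        }
        return value;
    }

    // Reads the holder from a freshly created child thread.
    static String readInChildThread() {
        final String[] seen = new String[1];
        Thread child = new Thread(() -> seen[0] = get());
        child.start();
        try {
            child.join();
        } catch (InterruptedException ex) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException(ex);
        }
        return seen[0];
    }

    static String demo() {
        HOLDER.set("plain");
        String fromPlain = readInChildThread();        // child thread sees nothing
        HOLDER.remove();

        INHERITABLE_HOLDER.set("inheritable");
        String fromInheritable = readInChildThread();  // child thread inherits the value
        INHERITABLE_HOLDER.remove();

        return fromPlain + "," + fromInheritable;
    }

    public static void main(String[] args) {
        System.out.println(demo()); // prints "null,inheritable"
    }
}
```

The practical consequence for the injected request is that calling it from a thread that has nothing bound will fail rather than silently return another thread's request.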
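- Where does the request stored in requestAttributesHolder come from?
A ThreadLocal is populated through set(). RequestContextHolder#setRequestAttributes calls ThreadLocal.set, and walking back up the call stack leads to the RequestContextFilter filter, which creates a ServletRequestAttributes and stores it in the ThreadLocal.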
// org.springframework.web.context.request.RequestContextHolder#setRequestAttributes
public static void setRequestAttributes(@Nullable RequestAttributes attributes, boolean inheritable) {
    if (attributes == null) {
        resetRequestAttributes();
    }
    else {
        if (inheritable) {
            inheritableRequestAttributesHolder.set(attributes);
            requestAttributesHolder.remove();
        }
        else {
            requestAttributesHolder.set(attributes);
            inheritableRequestAttributesHolder.remove();
        }
    }
}
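To see why a single shared accessor is safe across concurrent requests, here is a minimal sketch of the bind-then-clean-up pattern. It only approximates what a filter such as RequestContextFilter does around the filter chain: RequestBindingDemo, doFilter and currentRemoteAddr are hypothetical names, and the ThreadLocal stores just the remote address instead of ServletRequestAttributes.

```java
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

final class RequestBindingDemo {

    // Plays the role of RequestContextHolder's requestAttributesHolder.
    private static final ThreadLocal<String> CURRENT_REMOTE_ADDR = new ThreadLocal<>();

    // What a singleton bean effectively does each time it touches the injected request.
    static String currentRemoteAddr() {
        String addr = CURRENT_REMOTE_ADDR.get();
        if (addr == null) {
            throw new IllegalStateException("No thread-bound request found");
        }
        return addr;
    }

    // Hypothetical filter: bind before the chain runs, always clean up afterwards.
    static void doFilter(String remoteAddr, Runnable chain) {
        CURRENT_REMOTE_ADDR.set(remoteAddr);
        try {
            chain.run();
        } finally {
            CURRENT_REMOTE_ADDR.remove();
        }
    }

    static Map<String, String> demo() {
        Map<String, String> seen = new ConcurrentHashMap<>();
        // Two "request threads" share the same accessor but see different values.
        Thread first = new Thread(() -> doFilter("10.0.0.1", () -> seen.put("first", currentRemoteAddr())));
        Thread second = new Thread(() -> doFilter("10.0.0.2", () -> seen.put("second", currentRemoteAddr())));
        first.start();
        second.start();
        try {
            first.join();
            second.join();
        } catch (InterruptedException ex) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException(ex);
        }
        // The calling thread never went through doFilter, so nothing is bound to it.
        try {
            seen.put("main", currentRemoteAddr());
        } catch (IllegalStateException ex) {
            seen.put("main", "unbound");
        }
        return seen;
    }

    public static void main(String[] args) {
        System.out.println(demo());
    }
}
```

The remove() in the finally block is the important design choice: servlet containers reuse pooled threads, so a value left behind could leak into the next request handled by the same thread.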
4.2 Request injected as a Controller method parameter
Set a breakpoint in the controller method and walk back up the call stack to find where the request comes from.
// org.springframework.web.method.support.InvocableHandlerMethod#getMethodArgumentValues
protected Object[] getMethodArgumentValues(NativeWebRequest request, @Nullable ModelAndViewContainer mavContainer,
        Object... providedArgs) throws Exception {
    MethodParameter[] parameters = getMethodParameters();
    if (ObjectUtils.isEmpty(parameters)) {
        return EMPTY_ARGS;
    }
    Object[] args = new Object[parameters.length];
    for (int i = 0; i < parameters.length; i++) {
        MethodParameter parameter = parameters[i];
        parameter.initParameterNameDiscovery(this.parameterNameDiscoverer);
        args[i] = findProvidedArgument(parameter, providedArgs);
        if (args[i] != null) {
            continue;
        }
        if (!this.resolvers.supportsParameter(parameter)) {
            throw new IllegalStateException(formatArgumentError(parameter, "No suitable resolver"));
        }
        try {
            // Execution reaches this point: the argument is resolved by an argument resolver
            args[i] = this.resolvers.resolveArgument(parameter, mavContainer, request, this.dataBinderFactory);
        }
        catch (Exception ex) {
            throw ex;
        }
    }
    return args;
}

// org.springframework.web.method.support.HandlerMethodArgumentResolverComposite#resolveArgument
public Object resolveArgument(MethodParameter parameter, @Nullable ModelAndViewContainer mavContainer,
        NativeWebRequest webRequest, @Nullable WebDataBinderFactory binderFactory) throws Exception {
    // Iterates over argumentResolvers internally and picks ServletRequestMethodArgumentResolver
    HandlerMethodArgumentResolver resolver = getArgumentResolver(parameter);
    return resolver.resolveArgument(parameter, mavContainer, webRequest, binderFactory);
}

// org.springframework.web.servlet.mvc.method.annotation.ServletRequestMethodArgumentResolver#resolveArgument
public Object resolveArgument(MethodParameter parameter, @Nullable ModelAndViewContainer mavContainer, NativeWebRequest webRequest, @Nullable WebDataBinderFactory binderFactory) throws Exception {
    Class<?> paramType = parameter.getParameterType();
    if (WebRequest.class.isAssignableFrom(paramType)) {
        if (!paramType.isInstance(webRequest)) {
            String var10002 = paramType.getName();
            throw new IllegalStateException("Current request is not of type [" + var10002 + "]: " + webRequest);
        } else {
            return webRequest;
        }
    } else {
        // For an HttpServletRequest parameter the ternary takes the second branch:
        // this.resolveNativeRequest(webRequest, paramType)
        return !ServletRequest.class.isAssignableFrom(paramType) && !MultipartRequest.class.isAssignableFrom(paramType)
                ? this.resolveArgument(paramType, (HttpServletRequest)this.resolveNativeRequest(webRequest, HttpServletRequest.class))
                : this.resolveNativeRequest(webRequest, paramType);
    }
}

// org.springframework.web.servlet.mvc.method.annotation.ServletRequestMethodArgumentResolver#resolveNativeRequest
private <T> T resolveNativeRequest(NativeWebRequest webRequest, Class<T> requiredType) {
    // webRequest is the ServletWebRequest, which wraps the native Request
    T nativeRequest = webRequest.getNativeRequest(requiredType);
    if (nativeRequest == null) {
        String var10002 = requiredType.getName();
        throw new IllegalStateException("Current request is not of type [" + var10002 + "]: " + webRequest);
    } else {
        return nativeRequest;
    }
}

// org.springframework.web.context.request.ServletWebRequest#getNativeRequest
public <T> T getNativeRequest(@Nullable Class<T> requiredType) {
    // getRequest() returns the native Request
    return WebUtils.getNativeRequest(getRequest(), requiredType);
}
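Question: how is the ServletWebRequest object created?
Looking back, the getMethodArgumentValues method already receives the ServletWebRequest as a parameter, so where is it created? Walking up the call stack once more gives the answer: it is created in the HandlerAdapter.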
// org.springframework.web.servlet.mvc.method.annotation.RequestMappingHandlerAdapter#invokeHandlerMethod
@Nullable
protected ModelAndView invokeHandlerMethod(HttpServletRequest request, HttpServletResponse response, HandlerMethod handlerMethod) throws Exception {
    // Creates the AsyncWebRequest (a StandardServletAsyncWebRequest, which extends ServletWebRequest);
    // it wraps the native request and response
    AsyncWebRequest asyncWebRequest = WebAsyncUtils.createAsyncWebRequest(request, response);
    asyncWebRequest.setTimeout(this.asyncRequestTimeout);
    // code omitted...
    ServletWebRequest webRequest = asyncWebRequest instanceof ServletWebRequest ? (ServletWebRequest)asyncWebRequest : new ServletWebRequest(request, response);
    // code omitted...
    // invokeAndHandle eventually passes the ServletWebRequest on to getMethodArgumentValues as its parameter
    invocableMethod.invokeAndHandle(webRequest, mavContainer, new Object[0]);
    return asyncManager.isConcurrentHandlingStarted() ? null : this.getModelAndView(mavContainer, modelFactory, webRequest);
}

// org.springframework.web.context.request.async.WebAsyncUtils#createAsyncWebRequest
public static AsyncWebRequest createAsyncWebRequest(HttpServletRequest request, HttpServletResponse response) {
    AsyncWebRequest prev = getAsyncManager(request).getAsyncWebRequest();
    return (prev instanceof StandardServletAsyncWebRequest standardRequest ?
            new StandardServletAsyncWebRequest(request, response, standardRequest) :
            new StandardServletAsyncWebRequest(request, response));
}
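The whole chain in this section boils down to: look at the handler method's parameter types, ask a resolver for each argument, then invoke the method reflectively. A minimal sketch of that idea follows. It is a simplification under stated assumptions and not Spring MVC's code: FakeRequest, ArgumentResolver, HelloController and invokeHandler are hypothetical stand-ins for the native request, HandlerMethodArgumentResolver, a controller and InvocableHandlerMethod.

```java
import java.lang.reflect.Method;
import java.util.Collections;
import java.util.List;

final class ArgumentResolverDemo {

    // Hypothetical stand-in for the native HttpServletRequest.
    static final class FakeRequest {
        final String remoteAddr;
        FakeRequest(String remoteAddr) { this.remoteAddr = remoteAddr; }
    }

    // Simplified shape of HandlerMethodArgumentResolver.
    interface ArgumentResolver {
        boolean supportsParameter(Class<?> parameterType);
        Object resolveArgument(Class<?> parameterType, FakeRequest nativeRequest);
    }

    // Hands the native request through unchanged: no proxy, no container lookup.
    static final class RequestArgumentResolver implements ArgumentResolver {
        @Override
        public boolean supportsParameter(Class<?> parameterType) {
            return FakeRequest.class.isAssignableFrom(parameterType);
        }
        @Override
        public Object resolveArgument(Class<?> parameterType, FakeRequest nativeRequest) {
            return nativeRequest;
        }
    }

    static final List<ArgumentResolver> RESOLVERS =
            Collections.<ArgumentResolver>singletonList(new RequestArgumentResolver());

    static final class HelloController {
        public String hello(FakeRequest request) { return "hello from " + request.remoteAddr; }
        public Object echo(FakeRequest request) { return request; }
    }

    static Object resolve(Class<?> parameterType, FakeRequest nativeRequest) {
        for (ArgumentResolver resolver : RESOLVERS) {
            if (resolver.supportsParameter(parameterType)) {
                return resolver.resolveArgument(parameterType, nativeRequest);
            }
        }
        throw new IllegalStateException("No suitable resolver for " + parameterType.getName());
    }

    // Resolve every parameter, then call the handler method like any other method.
    static Object invokeHandler(Object handler, String methodName, FakeRequest nativeRequest) {
        for (Method method : handler.getClass().getMethods()) {
            if (!method.getName().equals(methodName)) {
                continue;
            }
            Class<?>[] types = method.getParameterTypes();
            Object[] args = new Object[types.length];
            for (int i = 0; i < types.length; i++) {
                args[i] = resolve(types[i], nativeRequest);
            }
            try {
                return method.invoke(handler, args);
            } catch (ReflectiveOperationException ex) {
                throw new IllegalStateException(ex);
            }
        }
        throw new IllegalArgumentException("No such handler method: " + methodName);
    }

    public static void main(String[] args) {
        FakeRequest request = new FakeRequest("10.0.0.1");
        System.out.println(invokeHandler(new HelloController(), "hello", request)); // prints "hello from 10.0.0.1"
        System.out.println(invokeHandler(new HelloController(), "echo", request) == request); // prints "true"
    }
}
```

The echo call returning the very same instance is the point of contrast with the Service case: here the controller receives the request object itself, passed as an ordinary argument.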
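5. Conclusion
- The HttpServletRequest injected as a Controller method parameter is passed to the Controller by DispatcherServlet through an ordinary method call (via the HandlerAdapter and its argument resolvers); it is not obtained through Spring container injection.
- The HttpServletRequest injected directly into a Service through the Spring container is a proxy backed by an ObjectFactory, a ThreadLocal, and RequestContextFilter (a OncePerRequestFilter); the ThreadLocal is what guarantees thread safety.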