Why HttpServletRequest can be injected directly with @Autowired, and why it is thread-safe

1. Scenario

Requirement: make PrintService also print the IP address of the current request.

@RestController
public class HelloController {

    private HelloService helloService;

    public HelloController(HelloService helloService) {
        this.helloService = helloService;
    }

    @GetMapping("hello")
    public String hello() {
        helloService.service();
        return "hello";
    }
}

@Service
public class HelloService {

    private PrintService printService;

    public HelloService(PrintService printService) {
        this.printService = printService;
    }

    public void service() {
        printService.print();
    }
}

@Service
public class PrintService {

    public void print() {
        System.out.println("hello, " + UUID.randomUUID());
    }
}

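2. Solutions

2.1 Option 1: receive HttpServletRequest in the Controller

My earlier approach was to add an HttpServletRequest parameter to the Controller method. This gives direct access to the current request. I then passed the request down through each layer.

The drawback is obvious. Every business method along the way gains an HttpServletRequest parameter it does not otherwise need, and the code becomes less elegant.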

@RestController
public class HelloController {

    private HelloService helloService;

    public HelloController(HelloService helloService) {
        this.helloService = helloService;
    }

    @GetMapping("hello")
    public String hello(HttpServletRequest request) {
        helloService.service(request);
        return "hello";
    }
}

@Service
public class HelloService {

    private PrintService printService;

    public HelloService(PrintService printService) {
        this.printService = printService;
    }

    public void service(HttpServletRequest request) {
        printService.print(request);
    }
}

@Service
public class PrintService {

    public void print(HttpServletRequest request) {
        System.out.println("IP: " + request.getRemoteAddr()); // getRemoteHost() may return a host name instead of the IP
        System.out.println("hello, " + UUID.randomUUID());
    }
}

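2.2 Option 2: inject HttpServletRequest directly into the Service (recommended)

In this version, HttpServletRequest is injected directly into the Service. That alone is enough to read the current request's information.

The result is the same, but it raised two questions:

  1. Are the HttpServletRequest objects obtained in these two ways the same?
  2. Isn't holding the request in a member field of a singleton bean a thread-safety problem?
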
@RestController
public class HelloController {

    private final HelloService helloService;

    public HelloController(HelloService helloService) {
        this.helloService = helloService;
    }

    @GetMapping("hello")
    public String hello() {
        helloService.service();
        return "hello";
    }
}

@Service
public class HelloService {

    private final PrintService printService;

    public HelloService(PrintService printService) {
        this.printService = printService;
    }

    public void service() {
        printService.print();
    }
}

@Service
public class PrintService {
    
    private final HttpServletRequest request;

    public PrintService(HttpServletRequest request) {
        this.request = request;
    }

    public void print() {
        System.out.println("IP: " + request.getRemoteAddr()); // getRemoteHost() may return a host name instead of the IP
        System.out.println("hello, " + UUID.randomUUID());
    }
}

3. Comparison

3.1 The Request injected directly into the Service

Internally, it holds an ObjectFactory.

3.2 The Request received as a Controller method parameter

Internally, it is the native request object.

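4. How it works

4.1 The Request injected directly into the Service

Inspecting the type of the injected request shows that it is a proxy object. Internally, it is backed by org.springframework.web.context.support.WebApplicationContextUtils.RequestObjectFactory: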

// Implements ObjectFactory: a factory that produces the object on demand
private static class RequestObjectFactory implements ObjectFactory<ServletRequest>, Serializable {

    @Override
    public ServletRequest getObject() {
        return currentRequestAttributes().getRequest();
    }

    @Override
    public String toString() {
        return "Current HttpServletRequest";
    }
}
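
  1. Where is RequestObjectFactory registered with the Spring container?

IDEA's Find Usages makes this easy to locate. RequestObjectFactory is registered in WebApplicationContextUtils#registerWebApplicationScopes, alongside the ObjectFactory instances for ServletResponse, HttpSession and WebRequest.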

//org.springframework.web.context.support.WebApplicationContextUtils#registerWebApplicationScopes
public static void registerWebApplicationScopes(ConfigurableListableBeanFactory beanFactory,
			@Nullable ServletContext sc) {

    beanFactory.registerScope(WebApplicationContext.SCOPE_REQUEST, new RequestScope());
    beanFactory.registerScope(WebApplicationContext.SCOPE_SESSION, new SessionScope());
    if (sc != null) {
        ServletContextScope appScope = new ServletContextScope(sc);
        beanFactory.registerScope(WebApplicationContext.SCOPE_APPLICATION, appScope);
        // Register as ServletContext attribute, for ContextCleanupListener to detect it.
        sc.setAttribute(ServletContextScope.class.getName(), appScope);
    }
    // Register RequestObjectFactory as the resolvable dependency for ServletRequest
    beanFactory.registerResolvableDependency(ServletRequest.class, new RequestObjectFactory());
    beanFactory.registerResolvableDependency(ServletResponse.class, new ResponseObjectFactory());
    beanFactory.registerResolvableDependency(HttpSession.class, new SessionObjectFactory());
    beanFactory.registerResolvableDependency(WebRequest.class, new WebRequestObjectFactory());
    if (jsfPresent) {
        FacesDependencyRegistrar.registerFacesDependencies(beanFactory);
    }
}

// org.springframework.beans.factory.support.DefaultListableBeanFactory#registerResolvableDependency
@Override
public void registerResolvableDependency(Class<?> dependencyType, @Nullable Object autowiredValue) {
    Assert.notNull(dependencyType, "Dependency type must not be null");
    if (autowiredValue != null) {
        // resolvableDependencies is a Map field of DefaultListableBeanFactory, Spring's default (though not the only) BeanFactory implementation
        this.resolvableDependencies.put(dependencyType, autowiredValue);
    }
}
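Conceptually, resolvableDependencies is a type-keyed map. Spring consults it with isAssignableFrom when it later resolves a dependency. That is why a factory registered under ServletRequest also satisfies a dependency on HttpServletRequest. The plain-Java sketch below mimics only that idea, without Spring. ObjectFactory, ServletRequestLike, HttpServletRequestLike and ResolvableRegistry are hypothetical stand-ins. The real DefaultListableBeanFactory does much more, for example wrapping the factory before injecting it.

```java
import java.util.LinkedHashMap;
import java.util.Map;

// Hypothetical stand-in for Spring's ObjectFactory.
interface ObjectFactory<T> {
    T getObject();
}

// Hypothetical stand-ins for ServletRequest / HttpServletRequest.
interface ServletRequestLike {
    String getRemoteAddr();
}

interface HttpServletRequestLike extends ServletRequestLike {
    String getMethod();
}

// Simplified mimic of DefaultListableBeanFactory#resolvableDependencies.
class ResolvableRegistry {
    private final Map<Class<?>, Object> resolvableDependencies = new LinkedHashMap<>();

    void registerResolvableDependency(Class<?> dependencyType, Object autowiredValue) {
        if (autowiredValue != null) {
            resolvableDependencies.put(dependencyType, autowiredValue);
        }
    }

    // A value registered under a supertype (ServletRequestLike) also matches
    // a lookup for a subtype (HttpServletRequestLike), via isAssignableFrom.
    Object findCandidate(Class<?> requiredType) {
        for (Map.Entry<Class<?>, Object> entry : resolvableDependencies.entrySet()) {
            if (entry.getKey().isAssignableFrom(requiredType)) {
                return entry.getValue();
            }
        }
        return null; // no resolvable dependency for this type
    }
}
```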
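
  2. How does it get injected into the Service?

At registration time, the ObjectFactory is put into the resolvableDependencies map of DefaultListableBeanFactory. So the next step is to find where resolvableDependencies is read: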

//org.springframework.beans.factory.support.DefaultListableBeanFactory#findAutowireCandidates
protected Map<String, Object> findAutowireCandidates(
			@Nullable String beanName, Class<?> requiredType, DependencyDescriptor descriptor) {

    String[] candidateNames = BeanFactoryUtils.beanNamesForTypeIncludingAncestors(
        this, requiredType, true, descriptor.isEager());
    Map<String, Object> result = CollectionUtils.newLinkedHashMap(candidateNames.length);
    // Walk resolvableDependencies: ServletRequest is assignable from HttpServletRequest, so RequestObjectFactory matches
    for (Map.Entry<Class<?>, Object> classObjectEntry : this.resolvableDependencies.entrySet()) {
        Class<?> autowiringType = classObjectEntry.getKey();
        if (autowiringType.isAssignableFrom(requiredType)) {
            Object autowiringValue = classObjectEntry.getValue();
            // Turn the ObjectFactory into an injectable value (here: a proxy implementing HttpServletRequest)
            autowiringValue = AutowireUtils.resolveAutowiringValue(autowiringValue, requiredType);
            if (requiredType.isInstance(autowiringValue)) {
                result.put(ObjectUtils.identityToString(autowiringValue), autowiringValue);
                break;
            }
        }
    }
    // 省略...
    return result;
}
// org.springframework.beans.factory.support.AutowireUtils#resolveAutowiringValue
public static Object resolveAutowiringValue(Object autowiringValue, Class<?> requiredType) {
    if (autowiringValue instanceof ObjectFactory<?> factory && !requiredType.isInstance(autowiringValue)) {
        if (autowiringValue instanceof Serializable && requiredType.isInterface()) {
            autowiringValue = Proxy.newProxyInstance(requiredType.getClassLoader(),
                                                     new Class<?>[] {requiredType}, new ObjectFactoryDelegatingInvocationHandler(factory));
        }
        else {
            // Otherwise call ObjectFactory.getObject() once and inject the result directly
            return factory.getObject();
        }
    }
    return autowiringValue;
}
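A single proxy in a final field stays correct because the proxy never stores a request. Spring's ObjectFactoryDelegatingInvocationHandler calls objectFactory.getObject() again on every method call. The sketch below reproduces that delegation in plain Java with java.lang.reflect.Proxy. It assumes hypothetical ObjectFactory and RequestLike interfaces in place of Spring's and the Servlet API's. Its equals/hashCode/toString handling follows the same idea as Spring's handler but is simplified.

```java
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Proxy;
import java.util.concurrent.atomic.AtomicReference;

// Hypothetical stand-ins for Spring's ObjectFactory and the Servlet request interface.
interface ObjectFactory<T> {
    T getObject();
}

interface RequestLike {
    String getRemoteAddr();
}

class DelegatingProxyDemo {

    // Simplified version of the idea behind ObjectFactoryDelegatingInvocationHandler:
    // the target is looked up again on every call instead of being stored.
    static <T> T createProxy(Class<T> type, ObjectFactory<? extends T> factory) {
        InvocationHandler handler = (proxy, method, args) -> {
            switch (method.getName()) {
                case "equals":
                    return proxy == args[0];           // identity of the proxy itself
                case "hashCode":
                    return System.identityHashCode(proxy);
                case "toString":
                    return factory.toString();
                default:
                    try {
                        return method.invoke(factory.getObject(), args);
                    } catch (InvocationTargetException ex) {
                        throw ex.getTargetException(); // rethrow the target's own exception
                    }
            }
        };
        return type.cast(Proxy.newProxyInstance(type.getClassLoader(), new Class<?>[] {type}, handler));
    }

    // Usage: one proxy instance, while the "current" target changes underneath it.
    static String[] demo() {
        AtomicReference<RequestLike> current = new AtomicReference<RequestLike>(() -> "10.0.0.1");
        RequestLike proxy = createProxy(RequestLike.class, current::get);
        String first = proxy.getRemoteAddr();
        current.set(() -> "10.0.0.2");
        String second = proxy.getRemoteAddr();    // same proxy, new target
        return new String[] {first, second};
    }
}
```

In demo(), the proxy instance never changes, yet the second call reaches a different target. In Spring, the "current target" is whatever the ThreadLocal holds for the calling thread.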

  3. How is thread safety guaranteed?

Start from RequestObjectFactory#getObject and follow the calls:

private static class RequestObjectFactory implements ObjectFactory<ServletRequest>, Serializable {
    @Override
    public ServletRequest getObject() {
        return currentRequestAttributes().getRequest();
    }
}
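// org.springframework.web.context.support.WebApplicationContextUtils#currentRequestAttributes
private static ServletRequestAttributes currentRequestAttributes() {
    RequestAttributes requestAttr = RequestContextHolder.currentRequestAttributes();
    if (!(requestAttr instanceof ServletRequestAttributes servletRequestAttributes)) {
        throw new IllegalStateException("Current request is not a servlet request");
    }
    return servletRequestAttributes;
}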

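// org.springframework.web.context.request.RequestContextHolder#currentRequestAttributes
public static RequestAttributes currentRequestAttributes() throws IllegalStateException {
    RequestAttributes attributes = getRequestAttributes();
    if (attributes == null) {
        // JSF fallback omitted; with no request bound to the current thread this throws
        throw new IllegalStateException("No thread-bound request found"); // message shortened
    }
    return attributes;
}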
// org.springframework.web.context.request.RequestContextHolder#getRequestAttributes
@Nullable
public static RequestAttributes getRequestAttributes() {
    // requestAttributesHolder is a ThreadLocal: each thread only sees the attributes bound to it, which is what makes this thread-safe
    RequestAttributes attributes = requestAttributesHolder.get();
    if (attributes == null) {
        attributes = inheritableRequestAttributesHolder.get();
    }
    return attributes;
}
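The isolation a ThreadLocal provides can be checked without Spring. The sketch below uses a hypothetical RequestHolder that mimics the requestAttributesHolder field, with a String standing in for the request attributes. Two threads bind different values to the same static holder, and each reads back only its own.

```java
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

// Hypothetical mimic of RequestContextHolder's requestAttributesHolder;
// a String stands in for the request attributes.
class RequestHolder {
    private static final ThreadLocal<String> HOLDER = new ThreadLocal<>();

    static void set(String request) { HOLDER.set(request); }

    static String get() { return HOLDER.get(); }

    static void reset() { HOLDER.remove(); }

    // Two threads use the same static holder; each reads back only its own value.
    static Map<String, String> demo() {
        Map<String, String> seen = new ConcurrentHashMap<>();
        Thread a = new Thread(() -> {
            set("request-A");
            pause(50);                 // give thread B a chance to bind its own value first
            seen.put("A", get());
            reset();
        });
        Thread b = new Thread(() -> {
            set("request-B");
            seen.put("B", get());
            reset();
        });
        a.start();
        b.start();
        join(a);
        join(b);
        return seen;
    }

    private static void pause(long millis) {
        try {
            Thread.sleep(millis);
        } catch (InterruptedException ex) {
            Thread.currentThread().interrupt();
        }
    }

    private static void join(Thread t) {
        try {
            t.join();
        } catch (InterruptedException ex) {
            Thread.currentThread().interrupt();
        }
    }
}
```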
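
  4. Where does the request in requestAttributesHolder come from?

A ThreadLocal is populated through set(). RequestContextHolder#setRequestAttributes calls ThreadLocal.set. Walking the call stack back from there leads to RequestContextFilter. This filter creates a ServletRequestAttributes for each request and stores it in the ThreadLocal. It clears the ThreadLocal again in a finally block once the request completes.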

// org.springframework.web.context.request.RequestContextHolder#setRequestAttributes
public static void setRequestAttributes(@Nullable RequestAttributes attributes, boolean inheritable) {
    if (attributes == null) {
        resetRequestAttributes();
    }
    else {
        if (inheritable) {
            inheritableRequestAttributesHolder.set(attributes);
            requestAttributesHolder.remove();
        }
        else {
            requestAttributesHolder.set(attributes);
            inheritableRequestAttributesHolder.remove();
        }
    }
}
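RequestContextFilter binds the attributes before the rest of the filter chain runs and resets them in a finally block. Servlet containers reuse worker threads, so that reset matters. Without it, the next request handled by the same thread could see the previous request. Below is a hedged sketch of this bind/reset pattern on a single reused thread. BindResetDemo and its String "request" are hypothetical stand-ins, not Spring's filter.

```java
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.function.Supplier;

// Hypothetical sketch of the bind / run chain / reset pattern; a String stands in for the request.
class BindResetDemo {
    static final ThreadLocal<String> HOLDER = new ThreadLocal<>();

    // Like the filter: bind, run the rest of the chain, always reset.
    static <T> T doFilter(String request, Supplier<T> chain) {
        HOLDER.set(request);
        try {
            return chain.get();
        } finally {
            HOLDER.remove();
        }
    }

    // A single worker thread handles two tasks in turn, like a pooled container thread.
    static String[] demo() {
        ExecutorService worker = Executors.newSingleThreadExecutor();
        try {
            Future<String> first = worker.submit(() -> doFilter("request-1", HOLDER::get));
            // The second task runs on the same thread without binding anything.
            Future<String> second = worker.submit(() -> HOLDER.get());
            return new String[] {first.get(), second.get()};
        } catch (InterruptedException | ExecutionException ex) {
            throw new IllegalStateException(ex);
        } finally {
            worker.shutdown();
        }
    }
}
```

If HOLDER.remove() is deleted from doFilter, the second task returns "request-1" instead of null. That stale-request leak is what the reset prevents.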

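4.2 The Request received as a Controller method parameter

Set a breakpoint in the controller method, then walk back up the call stack to find where the request comes from.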

//org.springframework.web.method.support.InvocableHandlerMethod#getMethodArgumentValues
protected Object[] getMethodArgumentValues(NativeWebRequest request, @Nullable ModelAndViewContainer mavContainer,
			Object... providedArgs) throws Exception {

    MethodParameter[] parameters = getMethodParameters();
    if (ObjectUtils.isEmpty(parameters)) {
        return EMPTY_ARGS;
    }

    Object[] args = new Object[parameters.length];
    for (int i = 0; i < parameters.length; i++) {
        MethodParameter parameter = parameters[i];
        parameter.initParameterNameDiscovery(this.parameterNameDiscoverer);
        args[i] = findProvidedArgument(parameter, providedArgs);
        if (args[i] != null) {
            continue;
        }
        if (!this.resolvers.supportsParameter(parameter)) {
            throw new IllegalStateException(formatArgumentError(parameter, "No suitable resolver"));
        }
        try {
            // Delegate to the argument resolvers to resolve this parameter
            args[i] = this.resolvers.resolveArgument(parameter, mavContainer, request, this.dataBinderFactory);
        }
        catch (Exception ex) {
            // debug logging omitted
            throw ex;
        }
    }
    return args;
}
//org.springframework.web.method.support.HandlerMethodArgumentResolverComposite#resolveArgument
public Object resolveArgument(MethodParameter parameter, @Nullable ModelAndViewContainer mavContainer,
			NativeWebRequest webRequest, @Nullable WebDataBinderFactory binderFactory) throws Exception {

    // Internally iterates over argumentResolvers and finds ServletRequestMethodArgumentResolver
    HandlerMethodArgumentResolver resolver = getArgumentResolver(parameter);

    return resolver.resolveArgument(parameter, mavContainer, webRequest, binderFactory);
}
// org.springframework.web.servlet.mvc.method.annotation.ServletRequestMethodArgumentResolver#resolveArgument
public Object resolveArgument(MethodParameter parameter, @Nullable ModelAndViewContainer mavContainer,
        NativeWebRequest webRequest, @Nullable WebDataBinderFactory binderFactory) throws Exception {

    Class<?> paramType = parameter.getParameterType();

    // WebRequest / NativeWebRequest / ServletWebRequest: return the wrapper itself
    if (WebRequest.class.isAssignableFrom(paramType)) {
        if (!paramType.isInstance(webRequest)) {
            throw new IllegalStateException(
                    "Current request is not of type [" + paramType.getName() + "]: " + webRequest);
        }
        return webRequest;
    }

    // ServletRequest / HttpServletRequest / MultipartRequest: the branch taken for an
    // HttpServletRequest parameter, unwrapping the native request from the ServletWebRequest
    if (ServletRequest.class.isAssignableFrom(paramType) || MultipartRequest.class.isAssignableFrom(paramType)) {
        return resolveNativeRequest(webRequest, paramType);
    }

    // Other supported types (HttpSession, Principal, Locale, ...) are derived from the HttpServletRequest
    return resolveArgument(paramType, resolveNativeRequest(webRequest, HttpServletRequest.class));
}
// org.springframework.web.servlet.mvc.method.annotation.ServletRequestMethodArgumentResolver#resolveNativeRequest
private <T> T resolveNativeRequest(NativeWebRequest webRequest, Class<T> requiredType) {
    // webRequest is a ServletWebRequest, which wraps the native request
    T nativeRequest = webRequest.getNativeRequest(requiredType);
    if (nativeRequest == null) {
        throw new IllegalStateException(
                "Current request is not of type [" + requiredType.getName() + "]: " + webRequest);
    }
    return nativeRequest;
}

// org.springframework.web.context.request.ServletWebRequest#getNativeRequest
public <T> T getNativeRequest(@Nullable Class<T> requiredType) {
    // getRequest() returns the native request
    return WebUtils.getNativeRequest(getRequest(), requiredType);
}
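Stripped of Spring types, the resolution path above is a chain of resolvers. The first resolver that supports the parameter type unwraps the native request from a wrapper built for this call. The plain-Java sketch below is illustrative only. ArgumentResolver, NativeRequest, WebRequestWrapper and ResolverChain are hypothetical names that loosely mirror the roles of HandlerMethodArgumentResolver, HttpServletRequest, ServletWebRequest and HandlerMethodArgumentResolverComposite.

```java
import java.util.List;

// Hypothetical stand-in for the container's native request.
class NativeRequest {
    final String remoteAddr;
    NativeRequest(String remoteAddr) { this.remoteAddr = remoteAddr; }
}

// Plays the role of ServletWebRequest: a per-call wrapper around the native request.
class WebRequestWrapper {
    private final NativeRequest nativeRequest;
    WebRequestWrapper(NativeRequest nativeRequest) { this.nativeRequest = nativeRequest; }

    <T> T getNativeRequest(Class<T> requiredType) {
        return requiredType.isInstance(nativeRequest) ? requiredType.cast(nativeRequest) : null;
    }
}

interface ArgumentResolver {
    boolean supportsParameter(Class<?> paramType);
    Object resolveArgument(Class<?> paramType, WebRequestWrapper webRequest);
}

class ResolverChain {
    private final List<ArgumentResolver> resolvers;
    ResolverChain(List<ArgumentResolver> resolvers) { this.resolvers = resolvers; }

    // Pick the first resolver that supports the parameter type.
    Object resolve(Class<?> paramType, WebRequestWrapper webRequest) {
        for (ArgumentResolver resolver : resolvers) {
            if (resolver.supportsParameter(paramType)) {
                return resolver.resolveArgument(paramType, webRequest);
            }
        }
        throw new IllegalStateException("No suitable resolver for " + paramType.getName());
    }

    // A resolver that, like ServletRequestMethodArgumentResolver, unwraps the native request.
    static ArgumentResolver nativeRequestResolver() {
        return new ArgumentResolver() {
            @Override
            public boolean supportsParameter(Class<?> paramType) {
                return NativeRequest.class.isAssignableFrom(paramType);
            }

            @Override
            public Object resolveArgument(Class<?> paramType, WebRequestWrapper webRequest) {
                Object nativeRequest = webRequest.getNativeRequest(paramType);
                if (nativeRequest == null) {
                    throw new IllegalStateException("Current request is not of type [" + paramType.getName() + "]");
                }
                return nativeRequest;
            }
        };
    }
}
```

The resolved argument is the very object the wrapper was built with. This matches the observation in 3.2 that the controller parameter is the native request.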

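Question: how is the ServletWebRequest object created?

Looking back, getMethodArgumentValues already receives the ServletWebRequest as a parameter. So where is it created? Back to the call stack once more...

Answer: it is created in the HandlerAdapter, in RequestMappingHandlerAdapter#invokeHandlerMethod.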

//org.springframework.web.servlet.mvc.method.annotation.RequestMappingHandlerAdapter#invokeHandlerMethod
@Nullable
protected ModelAndView invokeHandlerMethod(HttpServletRequest request, HttpServletResponse response, HandlerMethod handlerMethod) throws Exception {

    // Create the web request wrapper: a StandardServletAsyncWebRequest, i.e. a ServletWebRequest wrapping the native request and response
    AsyncWebRequest asyncWebRequest = WebAsyncUtils.createAsyncWebRequest(request, response);
    asyncWebRequest.setTimeout(this.asyncRequestTimeout);
    // code omitted...
    ServletWebRequest webRequest = asyncWebRequest instanceof ServletWebRequest ? (ServletWebRequest)asyncWebRequest : new ServletWebRequest(request, response);
    // code omitted...

    // invokeAndHandle eventually passes this ServletWebRequest into getMethodArgumentValues
    invocableMethod.invokeAndHandle(webRequest, mavContainer, new Object[0]);
    return asyncManager.isConcurrentHandlingStarted() ? null : this.getModelAndView(mavContainer, modelFactory, webRequest);
}
// org.springframework.web.context.request.async.WebAsyncUtils#createAsyncWebRequest
public static AsyncWebRequest createAsyncWebRequest(HttpServletRequest request, HttpServletResponse response) {
    AsyncWebRequest prev = getAsyncManager(request).getAsyncWebRequest();
    return (prev instanceof StandardServletAsyncWebRequest standardRequest ?
            new StandardServletAsyncWebRequest(request, response, standardRequest) :
            new StandardServletAsyncWebRequest(request, response));
}

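5. Conclusions

  1. The HttpServletRequest received as a Controller method parameter is not injected by the Spring container. DispatcherServlet passes the native request to RequestMappingHandlerAdapter, which wraps it in a ServletWebRequest. ServletRequestMethodArgumentResolver unwraps it again when resolving the method arguments, so the Controller receives the native request object itself.
  2. The HttpServletRequest injected into a Service does come from the Spring container. However, what gets injected is a JDK dynamic proxy backed by RequestObjectFactory, not the request itself. Each call to a request method on the proxy invokes getObject(), which reads the current thread's request from the ThreadLocal in RequestContextHolder. RequestContextFilter (a OncePerRequestFilter) binds that value at the start of each request. The ThreadLocal is what keeps the shared field thread-safe. On a thread with no bound request, such as a thread from a custom thread pool, calling the proxy throws IllegalStateException.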
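Putting the pieces together, here is a hedged, Spring-free simulation of the mechanism in the conclusions. A singleton service keeps one proxy in a final field. Each task binds its own request to a ThreadLocal before calling the service and clears it afterwards. Every call on the proxy is then routed to the calling thread's request. All type names here (RemoteRequest, InjectionSimulation, its nested PrintService) are hypothetical and only imitate the roles of the real classes.

```java
import java.lang.reflect.Proxy;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

// Hypothetical stand-in for HttpServletRequest.
interface RemoteRequest {
    String getRemoteAddr();
}

class InjectionSimulation {
    // Plays the role of RequestContextHolder's ThreadLocal.
    static final ThreadLocal<RemoteRequest> CURRENT = new ThreadLocal<>();

    // Plays the role of the injected proxy: every call looks up the current thread's request.
    static final RemoteRequest REQUEST_PROXY = (RemoteRequest) Proxy.newProxyInstance(
            RemoteRequest.class.getClassLoader(),
            new Class<?>[] {RemoteRequest.class},
            (proxy, method, args) -> {
                RemoteRequest target = CURRENT.get();
                if (target == null) {
                    throw new IllegalStateException("No thread-bound request found");
                }
                return method.invoke(target, args);
            });

    // Plays the role of PrintService: a singleton holding the proxy in a final field.
    static final class PrintService {
        private final RemoteRequest request;
        PrintService(RemoteRequest request) { this.request = request; }
        String print() { return "IP: " + request.getRemoteAddr(); }
    }

    // Simulates concurrent "requests" handled by a small pool of reused worker threads.
    static List<String> run(int requests) {
        PrintService service = new PrintService(REQUEST_PROXY);
        ExecutorService pool = Executors.newFixedThreadPool(4);
        try {
            List<Future<String>> futures = new ArrayList<>();
            for (int i = 0; i < requests; i++) {
                String addr = "10.0.0." + i;
                futures.add(pool.submit(() -> {
                    CURRENT.set(() -> addr);   // bind, as the filter does
                    try {
                        return service.print();
                    } finally {
                        CURRENT.remove();      // and always reset
                    }
                }));
            }
            List<String> results = new ArrayList<>();
            for (Future<String> f : futures) {
                results.add(f.get());
            }
            return results;
        } catch (InterruptedException | ExecutionException ex) {
            throw new IllegalStateException(ex);
        } finally {
            pool.shutdown();
        }
    }
}
```

Every result carries the address bound for its own task, even though all tasks share one service and one proxy. Calling REQUEST_PROXY outside such a task fails the same way the real proxy does outside a request.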
