oauth2关于websocket携带token的探讨
一、简述
前段时间,公司有个技术需求是做一个实时在线沟通功能,在遵循**合适、简单、演化**的原则下,我决定采用websocket技术。在此之前我对于websocket的了解仅限于《深入浅出springboot2.x》这本书里面很局限的知识,以及B站上某个一小时的视频。但实际上和项目进行整合时,仍然要解决各种问题。如、oauth2 token携带、依赖注入、对象入参出参等问题。针对这些问题,寻求对应的解决方案。
二、关于websocket请求携带token
原有项目采用的认证安全框架是spring security + oauth2,对于大部分请求,是需要在请求头携带token来实现的。同样的websocket只要携带了对应的authorization=Bearer +token也是可以正常建立连接的。但是比较坑的点就是通过vue js的方式没法指定对应的请求头的key(但通过postman是可以的),也就是意味着websocket没法这样子携带authorization=Bearer +token来建立连接了。针对这个问题,我提供了以下几种解决方案的思路。
2.1、通过websocket下的子协议来实现
可以尝试切换啊至Stomp这个协议来实现,前端采用SocketJs框架来实现对应定制请求头。实现携带authorization=Bearer +token 的需求,这样就可以正常建立连接
2.2、资源服务器放开请求路径。
针对websocket的请求,我们可以在oauth2的资源服务器选择放开此类请求,但是这样就意味着你任何websocket请求,不用携带任何token就可以访问,这又有什么意义呢?当然如果您不在乎的话,也不是不可以。
.antMatchers("/login", "/websocket/**")
.permitAll()
2.3、请求参数上携带access_token=token
我们虽然不能自定义请求头的key,但是我们可以自定义请求参数,只要携带了access_token=token,也能通过oauth2的授权。我们将进行其核心源码的分析,看一下oath2是如何将用户信息加载但security上下文的。
BearerTokenExtractor.java
public class BearerTokenExtractor implements TokenExtractor {
private final static Log logger = LogFactory.getLog(BearerTokenExtractor.class);
@Override
public Authentication extract(HttpServletRequest request) {
/**获取对应的token**/
String tokenValue = extractToken(request);
if (tokenValue != null) {
PreAuthenticatedAuthenticationToken authentication = new PreAuthenticatedAuthenticationToken(tokenValue, "");
return authentication;
}
return null;
}
/**
这个方法的本质就是获取token 先去请求头中获取token,如果为null,再去请求头的
access_token中获取, OAuth2AccessToken.ACCESS_TOKEN就是access_token
**/
protected String extractToken(HttpServletRequest request) {
// first check the header...
String token = extractHeaderToken(request);
// bearer type allows a request parameter as well
if (token == null) {
logger.debug("Token not found in headers. Trying request parameters.");
token = request.getParameter(OAuth2AccessToken.ACCESS_TOKEN);
if (token == null) {
logger.debug("Token not found in request parameters. Not an OAuth2 request.");
}
else {
request.setAttribute(OAuth2AuthenticationDetails.ACCESS_TOKEN_TYPE, OAuth2AccessToken.BEARER_TYPE);
}
}
return token;
}
/**
* Extract the OAuth bearer token from a header.
* 从请求头中获取对应的token
* @param request The request.
* @return The token, or null if no OAuth authorization header was supplied.
*/
protected String extractHeaderToken(HttpServletRequest request) {
Enumeration<String> headers = request.getHeaders("Authorization");
while (headers.hasMoreElements()) { // typically there is only one (most servers enforce that)
String value = headers.nextElement();
if ((value.toLowerCase().startsWith(OAuth2AccessToken.BEARER_TYPE.toLowerCase()))) {
String authHeaderValue = value.substring(OAuth2AccessToken.BEARER_TYPE.length()).trim();
// Add this here for the auth details later. Would be better to change the signature of this method.
request.setAttribute(OAuth2AuthenticationDetails.ACCESS_TOKEN_TYPE,
value.substring(0, OAuth2AccessToken.BEARER_TYPE.length()).trim());
int commaIndex = authHeaderValue.indexOf(',');
if (commaIndex > 0) {
authHeaderValue = authHeaderValue.substring(0, commaIndex);
}
return authHeaderValue;
}
}
return null;
}
}
简单分析一下,oauth2中获取token,并且将其加入到springsecurity的上下文的过程,本质是先从请求头上线获取,如果获取不到,再去请求参数上获取。所以只要将token放在对应的请求头上就行了。
2.4、请求websocket的请求头中携带sec-websocket-protocol=Bearer +token
对于有些强迫症比较严重的开发人员,觉得我就是不想在请求路径上携带token,就是想在请求头上携带token,那么接下来这种方式可能能够暂时满足你,虽然前端vue不能自定义请求头key,但是websocket允许请求头里面携带一个key为,sec-websocket-protocol的参数。那么我们可以做以下几步来实现。
-
前端人员在请求头上携带sec-websocket-protocol=Bearer +token
-
后台在请求到达oauth2之前进行拦截,然后将在请求头上添加Authorization=Bearer +token(key首字母大写),然后在响应头(respone)上添加sec-websocket-protocol=Bearer +token(不添加会报错)。
这样我们就完成了请求头上携带token的方法。相应的代码如下。
RequestReplaceFilter.java
package com.linkyoyo.bill.web;
import cn.hutool.core.util.StrUtil;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.core.Ordered;
import org.springframework.core.annotation.Order;
import org.springframework.http.MediaType;
import org.springframework.stereotype.Component;
import org.springframework.web.filter.OncePerRequestFilter;
import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.validation.constraints.NotNull;
import java.io.BufferedReader;
import java.io.IOException;
@Order(Ordered.HIGHEST_PRECEDENCE)
@Component
@Slf4j
public class RequestReplaceFilter extends OncePerRequestFilter {
@Override
protected void doFilterInternal(@NotNull HttpServletRequest request, @NotNull HttpServletResponse response, @NotNull FilterChain filterChain) throws ServletException, IOException {
String contentType = request.getContentType();
request = this.addTokenForWebSocket(request, response);
filterChain.doFilter(request, response);
}
private HttpServletRequest addTokenForWebSocket(HttpServletRequest request, HttpServletResponse response) { ;
String token = request.getHeader("authorization");
if(StrUtil.isNotBlank(token)) {
return request;
}
HeaderMapRequestWrapper requestWrapper = new HeaderMapRequestWrapper(request);
token = request.getHeader("sec-websocket-protocol");
if(StrUtil.isBlank(token)) {
return request;
}
requestWrapper.addHeader("Authorization", token);
response.addHeader("sec-websocket-protocol", token);
return requestWrapper;
}
}
HeaderMapRequestWrapper.java 这个是从网上复制下来的。
package com.linkyoyo.bill.web;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import java.util.*;
/**
* @author 罗富晓 [295006967@qq.com]
* @date 2022/2/21 15:24
*/
public class HeaderMapRequestWrapper extends HttpServletRequestWrapper {
/**
* Constructs a request object wrapping the given request.
*
* @param request The request to wrap
* @throws IllegalArgumentException if the request is null
*/
public HeaderMapRequestWrapper(HttpServletRequest request) {
super(request);
}
private Map<String, String> headerMap = new HashMap();
public void addHeader(String name, String value) {
headerMap.put(name, value);
}
public void removeHeader(String name) {
headerMap.remove(name);
}
@Override
public String getHeader(String name) {
String headerValue = super.getHeader(name);
if (headerMap.containsKey(name)) {
headerValue = headerMap.get(name);
}
return headerValue;
}
/**
* get the Header names
*/
@Override
public Enumeration<String> getHeaderNames() {
List<String> names = Collections.list(super.getHeaderNames());
for (String name : headerMap.keySet()) {
names.add(name);
}
return Collections.enumeration(names);
}
@Override
public Enumeration<String> getHeaders(String name) {
List<String> values = Collections.list(super.getHeaders(name));
if (headerMap.containsKey(name)) {
values.add(headerMap.get(name));
}
return Collections.enumeration(values);
}
}