一般来说,用户登录成功,服务端下发JWT,并将用户信息写入Redis。
用户调用业务接口,Filter拦截,校验JWT,并将Redis存储的用户信息写入ThreadLocal,例如LoginUserContext等。
Controller、Service等可以通过LoginUserContext获取当前请求的用户信息,校验一下是否有权限获取信息(订单的所有者与登录用户是否为同一人)、记录操作用户(一条记录的创建者、更新者)等。
我在实际的工作中,经常遇见ThreadLocal污染的局面(A用户登录了,上下文取到的却是B用户的信息),基本上都是没有及时销毁ThreadLocal所致。
虽然也很容易规避,但我一直想着能否不用ThreadLocal,毕竟我能力有限,才疏学浅,一看有“Thread”字样的,就犯怵。
因此参考一些资料,搞了以下方案:
通过在Filter向请求参数增加当前登录用户信息。
实现如下:
我们系统,一般就两种请求方式:GET、POST
POST只采用application/json的方式,Controller实现如下
@PostMapping("/order/add")
public void addOrder(@RequestBody OrderAddReq req){
log.info(GsonUtil.object2String(req));
}
GET 则有两种实现方式:
@GetMapping("/user/info")
public void userInfo(@RequestParam("userId") String userId){
log.info(userId);
}
@GetMapping("/user/getById")
public void getById(UserQueryReq req) {
log.info(GsonUtil.object2String(req));
}
前者适用于传参较少的情况,例如根据ID、订单号获取详情等,我们的开发规范一般要求参数<=3个的,用这种方式接收参数。
后者适用于传参较多的情况,例如列表查询,传什么状态、名称、时间范围、分页参数等。
Filter的思路就是针对以上情况,向Request内添加参数,这又有两种对参数的设置方式:
第一种方案:
@Data
public class LoginUserDto implements Serializable {
@Serial
private final static long serialVersionUID = 1L;
private String id;
}
@Data
public class UserQueryReq implements Serializable {
@Serial
private static final long serialVersionUID = 1L;
private String userId;
private LoginUserDto loginUser;
}
@Data
public class UserAddReq implements Serializable {
@Serial
private static final long serialVersionUID = 1L;
private String userId;
private String userName;
private LoginUserDto loginUser;
}
在Controller的实现改造如下:
@GetMapping("/user/info")
public void userInfo(@RequestParam("userId") String userId, @RequestParam("loginUser.id")String loginUserId){
log.info(userId);
log.info(loginUserId);
}
@GetMapping("/user/getById")
public void getById(UserReq req) {
log.info(GsonUtil.object2String(req));
}
@PostMapping("/user/add")
public void add(@RequestBody UserAddReq userAddReq) {
log.info(GsonUtil.object2String(userAddReq));
}
关键就在于第一个GET方法,增加了
@RequestParam("loginUser.id")String loginUserId。
取当前登录信息的方式就是
req.getLoginUser.getId();
第二种方案:
@Data
public class LoginUserBaseDto implements Serializable {
@Serial
private static final long serialVersionUID = 1L;
private String loginUserId;
}
@EqualsAndHashCode(callSuper = true)
@Data
public class OrderQueryReq extends LoginUserBaseDto implements Serializable {
@Serial
private static final long serialVersionUID = 1L;
private String orderNo;
}
@EqualsAndHashCode(callSuper = true)
@Data
public class OrderAddReq extends LoginUserBaseDto implements Serializable {
@Serial
private static final long serialVersionUID = 1L;
private String orderNo;
}
Controller的实现:
@GetMapping("/order/info")
public void getOrderIndo(@RequestParam("orderNo") String orderNo, @RequestParam("loginUserId")String loginUserId){
log.info(orderNo);
log.info(loginUserId);
}
@GetMapping("/order/getById")
public void getOrderById(OrderQueryReq req){
log.info(GsonUtil.object2String(req));
}
@PostMapping("/order/add")
public void addOrder(@RequestBody OrderAddReq req){
log.info(GsonUtil.object2String(req));
}
关键就在于第一个GET方法,增加了
@RequestParam("loginUserId")String loginUserId。
取当前登录信息的方式就是
req.getLoginUserId();
下面是Filter的实现方法:
需要两个RequestWrapper。
public class GetParamRequestWrapper extends HttpServletRequestWrapper {
private final Map<String, String[]> params = new HashMap<>();
public GetParamRequestWrapper(HttpServletRequest request) {
super(request);
//将参数表,赋予给当前Map以便于持有request中的参数
this.params.putAll(request.getParameterMap());
}
public GetParamRequestWrapper(HttpServletRequest request, Map<String, String[]> extendParams) {
this(request);
addAllParameters(extendParams);
}
/**
* 重写getParameter方法
*
* @param name 参数名
* @return 参数数值
*/
@Override
public String getParameter(String name) {
String[] values = params.get(name);
if (values == null) {
return null;
}
return values[0];
}
@Override
public String[] getParameterValues(String name) {
String[] values = params.get(name);
if (values == null || values.length == 0) {
return null;
}
return values;
}
/**
* 在获取所有的参数名,必须重写此方法,
* 否则对象中参数值映射不上
*
* @return
*/
@Override
public Enumeration<String> getParameterNames() {
return new Vector<>(params.keySet()).elements();
}
public Map<String, String[]> getParams() {
return params;
}
public void addAllParameters(Map<String, String[]> extendParams) {
for (Map.Entry<String, String[]> entry : extendParams.entrySet())
addParameter(entry.getKey(), entry.getValue());
}
public void addParameter(String key, Object value) {
if (value != null) {
if (value instanceof String[])
params.put(key, (String[]) value);
else if (value instanceof String)
params.put(key, new String[]{(String) value});
else
params.put(key, new String[]{String.valueOf(value)});
}
}
public class PostParamRequestWrapper extends HttpServletRequestWrapper {
/**
* 每次读取8kb
*/
private static final int BUFFER_SIZE = 1024 * 8;
/**
* 请求体
*/
private String body;
/**
* 构造方法
*/
public PostParamRequestWrapper(HttpServletRequest request) throws IOException {
super(request);
BufferedReader reader = request.getReader();
StringWriter writer = new StringWriter();
int read;
char[] buf = new char[BUFFER_SIZE];
while ((read = reader.read(buf)) != -1) {
writer.write(buf, 0, read);
}
this.body = writer.getBuffer().toString();
}
/**
* 获取请求体
*
* @return 请求体
*/
public String getBody() {
return body;
}
public void addLoginUser2Body(JSONObject loginUser) {
String json = this.body;
if (StringUtils.isBlank(json)) {
return;
}
JSONObject jo = new JSONObject(json);
jo.put("loginUser", loginUser);
this.body = jo.toString();
}
public void addLoginUserToBody(String key,String value) {
String json = this.body;
if (StringUtils.isBlank(json)) {
return;
}
JSONObject jo = new JSONObject(json);
jo.put(key,value);
this.body = jo.toString();
}
@Override
public ServletInputStream getInputStream() {
final ByteArrayInputStream bais = new ByteArrayInputStream(body.getBytes());
return new ServletInputStream() {
@Override
public boolean isFinished() {
return false;
}
@Override
public boolean isReady() {
return false;
}
@Override
public void setReadListener(ReadListener readListener) {
}
@Override
public int read() {
return bais.read();
}
};
}
@Override
public BufferedReader getReader() {
return new BufferedReader(new InputStreamReader(this.getInputStream()));
}
}
Filter:
@Component
@Slf4j
public class AddLoginUserToReqFilter implements Filter {
@Override
public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain) throws IOException,
ServletException {
HttpServletRequest hsr = (HttpServletRequest) servletRequest;
String url = hsr.getRequestURI();
String requestBody;
if ("GET".equals(hsr.getMethod())) {
GetParamRequestWrapper requestWrapper = new GetParamRequestWrapper(hsr, hsr.getParameterMap());
//第一种实现方式
//requestWrapper.addParameter("loginUser.id","1");
//第二种实现方式
requestWrapper.addParameter("loginUserId", "1");
requestBody = GsonUtil.object2String(requestWrapper.getParams());
log.info("【参数输出】URL:{} {},参数:{} ", requestWrapper.getMethod(), url, requestBody);
filterChain.doFilter(requestWrapper, servletResponse);
} else if ("POST".equals(hsr.getMethod())) {
PostParamRequestWrapper requestWrapper = new PostParamRequestWrapper(hsr);
JSONObject loginUser = new JSONObject();
loginUser.put("id", "2");
//第一种实现方式
//requestWrapper.addLoginUser2Body(loginUser);
//第二种实现方式
requestWrapper.addLoginUserToBody("loginUserId", "11111");
requestBody = requestWrapper.getBody();
log.info("【参数输出】URL:{} {},参数:{} ", requestWrapper.getMethod(), url, requestBody);
filterChain.doFilter(requestWrapper, servletResponse);
} else {
filterChain.doFilter(servletRequest, servletResponse);
}
}
}
几点说明:
-
一般一个项目,Filter很多,例如校验权限什么的,本Filter可以通过@Order()设置在最后一个。
-
要有一个FilterConfig类,可以参考我以前的文章。
-
当前登录用户,可以通过HEADER获取JWT,进行校验,然后通过Redis获取用户信息,本例就写死了。