/**
 * Spring MVC configuration that registers the application's request interceptors.
 */
@Configuration
public class InterceptorConfig implements WebMvcConfigurer {

    /**
     * Adds the two interceptors to the registry, in this order:
     * <ol>
     *   <li>{@link RestInterceptor}, mapped to every path;</li>
     *   <li>{@link AccessInterceptor}, mapped to every path except {@code /error}.</li>
     * </ol>
     *
     * @param registry the registry the interceptors are added to
     */
    @Override
    public void addInterceptors(InterceptorRegistry registry) {
        // First registration: applies to all request paths.
        final RestInterceptor restInterceptor = new RestInterceptor();
        registry.addInterceptor(restInterceptor).addPathPatterns("/**");

        // Second registration: applies to all request paths, but the error path is left out.
        final AccessInterceptor accessInterceptor = new AccessInterceptor();
        registry.addInterceptor(accessInterceptor)
                .addPathPatterns("/**")
                .excludePathPatterns("/error");
    }
}
/**
 * Interceptor that tags each request with a trace id and logs its arrival and handling time.
 *
 * <p>{@code preHandle} puts the trace id into the logging MDC and stores a start timestamp in a
 * thread-local; {@code afterCompletion} logs the elapsed time and clears both again.
 */
public class RestInterceptor extends HandlerInterceptorAdapter {
    // The logger is bound to this class (it used to be created for HandlerInterceptorAdapter,
    // which attributed every line logged here to the framework base class).
    private static final Logger logger = LoggerFactory.getLogger(RestInterceptor.class);

    /** Value of {@code System.currentTimeMillis()} taken in {@code preHandle} on the current thread. */
    private static final ThreadLocal<Long> requestStartMillis = new ThreadLocal<>();

    /**
     * Puts a trace id into the MDC, logs the incoming request and records the start time.
     *
     * <p>The trace id is taken from the {@code trace_id} request parameter; when that parameter is
     * blank a new id is obtained from {@code LogTraceHelper.generateTraceId()}.
     *
     * @param request  current HTTP request
     * @param response current HTTP response (not used)
     * @param handler  chosen handler (not used)
     * @return always {@code true}
     */
    @Override
    public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler)
            throws Exception {
        String traceId = request.getParameter("trace_id");
        if (StringUtils.isBlank(traceId)) {
            traceId = LogTraceHelper.generateTraceId();
        }
        // Set the MDC entry before logging so the arrival line already carries the trace id.
        MDC.put(Const.LogbackConfig.TRACE_ID, traceId);
        logger.info("Rest come in, uri:{}, remote ip:{}", request.getRequestURI(), LogTraceHelper.getRemoteHost(request));
        requestStartMillis.set(System.currentTimeMillis());
        return true;
    }

    /**
     * Logs the elapsed time since {@code preHandle} (when a start time was recorded) and then
     * removes the start time and the MDC trace id.
     *
     * @param request  current HTTP request (not used)
     * @param response current HTTP response (not used)
     * @param handler  the handler that was executed (not used)
     * @param ex       exception raised during handling, if any (not used)
     */
    @Override
    public void afterCompletion(HttpServletRequest request, HttpServletResponse response, Object handler, Exception ex)
            throws Exception {
        try {
            Long startedAt = requestStartMillis.get();
            if (startedAt != null) {
                long costTime = System.currentTimeMillis() - startedAt;
                logger.info("Rest finished, cost time:{} ms", costTime);
            }
        } finally {
            // Always clean up so no state from this request stays attached to the thread.
            requestStartMillis.remove();
            MDC.remove(Const.LogbackConfig.TRACE_ID);
        }
    }
}
/**
 * Interceptor that authenticates requests by {@code userId}, {@code source} and a {@code token}
 * header, and publishes the resolved user info through {@link UserSessionUtil}.
 *
 * <p>Requests whose URI contains {@code "/internal/"} or {@code "swagger"} are let through without
 * any check.
 */
public class AccessInterceptor extends HandlerInterceptorAdapter {
    private static final Logger logger = LoggerFactory.getLogger(AccessInterceptor.class);
    private static final String ACCESS_URI = "/internal/";
    private static final String SWAGGER_URI = "swagger";
    private static final String ILLEGAL_PARAM_MESSAGE = "Illegal Param: userId or source";

    /** Set in {@code preHandle} once a user session was published, so that {@code afterCompletion} knows to clear it. */
    private static final ThreadLocal<Boolean> requestInternal = new ThreadLocal<>();

    // NOTE(review): looked up once when this class is initialised; presumably the application
    // context must already be available at that point — confirm.
    private static UserTokenDao userTokenDao = ApplicationContextConfig.getBean("userTokenDao");

    /**
     * Validates the caller's credentials and, on success, stores the user info for the request.
     *
     * <p>Checks, in order: {@code userId} and {@code source} parameters are present; the
     * {@code token} header is present and passes {@code HashUtil.check}; {@code userId} and
     * {@code source} are numeric; the stored token for that user/source equals the supplied one.
     *
     * @param request  current HTTP request
     * @param response current HTTP response (not used)
     * @param handler  chosen handler (not used)
     * @return {@code true} when the request may proceed
     * @throws BusinessException         with BAD_REQUEST when {@code userId}/{@code source} are blank
     *                                   or not numeric, with FORBIDDEN when the token is blank or
     *                                   fails the hash check
     * @throws IllegalAuthorityException when no stored token exists or it differs from the supplied one
     */
    @Override
    public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler)
            throws Exception {
        String requestUri = request.getRequestURI();
        // NOTE(review): substring match — any URI that merely contains "swagger" or "/internal/"
        // skips authentication; confirm this is intended.
        if (requestUri.contains(ACCESS_URI) || requestUri.contains(SWAGGER_URI)) {
            return true;
        }

        String userId = request.getParameter("userId");
        String token = request.getHeader("token");
        String source = request.getParameter("source");
        // The token is a credential and is deliberately not written to the log.
        logger.info("api:{},userId:{},source:{}", requestUri, userId, source);

        if (StringUtils.isBlank(userId) || StringUtils.isBlank(source)) {
            throw new BusinessException(HttpStatus.BAD_REQUEST, ILLEGAL_PARAM_MESSAGE);
        }
        if (StringUtils.isBlank(token) || !HashUtil.check(userId + Const.UNDERLINE + source, token)) {
            String ip = LogTraceHelper.getRemoteHost(request);
            logger.warn("Illegal Token Access of:{}, IP:{}", userId, ip);
            throw new BusinessException(HttpStatus.FORBIDDEN, "访问非法");
        }

        // Both values come straight from the client; report malformed numbers as a bad request
        // instead of letting NumberFormatException escape as an unhandled error.
        Long userIdValue;
        Integer sourceValue;
        try {
            userIdValue = Long.valueOf(userId);
            sourceValue = Integer.valueOf(source);
        } catch (NumberFormatException e) {
            logger.warn("Non-numeric userId or source, userId:{}, source:{}", userId, source);
            throw new BusinessException(HttpStatus.BAD_REQUEST, ILLEGAL_PARAM_MESSAGE);
        }

        Map<String, String> userInfo = userTokenDao.getUserToken(userIdValue, sourceValue);
        boolean tokenMatches = userInfo != null
                && !userInfo.isEmpty()
                && token.equals(userInfo.get(Const.TokenInfo.TOKEN));
        if (!tokenMatches) {
            logger.warn("token invalid or expire of:{}", userId);
            throw new IllegalAuthorityException();
        }

        userInfo.put(Const.TokenInfo.USER_ID, userId);
        userInfo.put(Const.TokenInfo.IP, LogTraceHelper.getRemoteHost(request));
        requestInternal.set(false);
        UserSessionUtil.setUserInfo(userInfo);
        return true;
    }

    /**
     * Clears the user session and the per-request marker when {@code preHandle} set them.
     *
     * @param request  current HTTP request (not used)
     * @param response current HTTP response (not used)
     * @param handler  the handler that was executed (not used)
     * @param ex       exception raised during handling, if any (not used)
     */
    @Override
    public void afterCompletion(HttpServletRequest request, HttpServletResponse response, Object handler, Exception ex)
            throws Exception {
        if (requestInternal.get() != null) {
            try {
                UserSessionUtil.clear();
            } finally {
                // Remove the marker even if clearing the session fails.
                requestInternal.remove();
            }
        }
    }
}
/**
 * Static accessors for the user info map bound to the current thread.
 *
 * <p>The map is stored with {@link #setUserInfo(Map)} and removed with {@link #clear()}. Every
 * getter throws {@link NullPointerException} when no map is bound to the calling thread.
 */
public class UserSessionUtil {
    private static final Logger logger = LoggerFactory.getLogger(UserSessionUtil.class);

    // Fixed name: the previous name embedded the name of whichever thread happened to initialise
    // this class, which said nothing about the thread actually holding the value.
    private static final ThreadLocal<Map<String, String>> threadLocal = new NamedThreadLocal<>("sessionUser");

    /**
     * @return the {@code USER_ID} entry parsed as a {@code Long}
     * @throws NumberFormatException if the entry is missing or not numeric
     */
    public static Long getUserId() {
        Long userId = Long.valueOf(attribute(Const.TokenInfo.USER_ID));
        logger.info("session get userId:{}", userId);
        return userId;
    }

    /** @return the {@code ENTERPRISE_ID} entry, or {@code null} when the map has none */
    public static String getEnterpriseId() {
        String enterpriseId = attribute(Const.TokenInfo.ENTERPRISE_ID);
        logger.info("session get enterpriseId:{}", enterpriseId);
        return enterpriseId;
    }

    /**
     * @return the {@code AUTHORITY} entry parsed as an {@code int}
     * @throws NumberFormatException if the entry is missing or not numeric
     */
    public static int getAuthority() {
        int authority = Integer.parseInt(attribute(Const.TokenInfo.AUTHORITY));
        logger.info("session get authority:{}", authority);
        return authority;
    }

    /** @return the {@code IP} entry, or {@code null} when the map has none */
    public static String getIp() {
        String ip = attribute(Const.TokenInfo.IP);
        logger.info("session get ip:{}", ip);
        return ip;
    }

    /**
     * @return the {@code SOURCE} entry parsed as an {@code Integer}
     * @throws NumberFormatException if the entry is missing or not numeric
     */
    public static Integer getSource() {
        Integer source = Integer.valueOf(attribute(Const.TokenInfo.SOURCE));
        logger.info("session get source:{}", source);
        return source;
    }

    /** @return the {@code PHONE} entry, or {@code null} when the map has none */
    public static String getPhone() {
        String phone = attribute(Const.TokenInfo.PHONE);
        // NOTE(review): this writes the phone number to the log at INFO — confirm that is acceptable.
        logger.info("session get phone:{}", phone);
        return phone;
    }

    /** Removes the user info bound to the current thread, logging the user id when one was bound. */
    public static void clear() {
        Map<String, String> session = threadLocal.get();
        if (session != null) {
            logger.info("session remove userId:{}", session.get(Const.TokenInfo.USER_ID));
        }
        // Unconditional: also drops the empty entry that get() creates when nothing was bound.
        threadLocal.remove();
    }

    /**
     * Binds the given user info map to the current thread.
     *
     * @param userInfo user info map; must not be {@code null}
     */
    public static void setUserInfo(Map<String, String> userInfo) {
        logger.info("session set userInfo of:{}", userInfo.get(Const.TokenInfo.USER_ID));
        threadLocal.set(userInfo);
    }

    /**
     * Reads one entry from the map bound to the current thread.
     *
     * @param key entry key
     * @return the entry value, or {@code null} when the map has no such key
     * @throws NullPointerException if no map is bound to the current thread
     */
    private static String attribute(String key) {
        // Still a NullPointerException, as before, but now with a message that names the cause.
        Map<String, String> session = java.util.Objects.requireNonNull(
                threadLocal.get(), "No user session is bound to the current thread");
        return session.get(key);
    }
}