自定义权限管理:登录拦截

作者简介:大家好,我是smart哥,前中兴通讯、美团架构师,现某互联网公司CTO

联系qq:184480602,加我进群,大家一起学习,一起进步,一起对抗互联网寒冬

在大家的印象中,对于一个系统而言,需要登录的接口占多数还是不需要登录的接口占多数呢?

答案是不一定。

如果是后台接口,那么基本都是要登录的,而对于前台接口或者APP,那就有可能五五开了。比如很多APP都提供了游客模式,下载APP能浏览一些问题和回答,但要发表文章或评论时,就会提醒你登录。

为什么要问这个问题呢?

因为实现登录拦截时,有两种方案:

  • 默认拦截全部,只对部分接口放行
  • 默认不拦截,只对部分接口限制

具体选哪种,取决于你的系统属于哪种。本文假定大部分接口无需登录,只对部分接口做限制。代码几乎一样,大家可以自己改改。

这里只贴出核心代码,一切工具类封装请参考小册其他章节。

AnnotationUtils:一个好用的注解工具类

在正式开始之前,我们来了解一下AnnotationUtils。

它是Spring提供的一个注解工具类,提供了获取类上的注解、方法上的注解以及注解属性等便利的操作。你是否曾经疑惑过:Spring为什么能识别“叠加的注解”?比如,Spring如何识别@RestController?

我的理解是,Spring在代码层面应该没有直接识别@RestController,只认@Controller+@ResponseBody,之所以能识别@RestController,是因为它“追踪”到了@RestController上面还有@Controller和@ResponseBody。Spring是如何“追踪”到嵌套注解的呢?答案就是AnnotationUtils!我们通过几个案例简单学习一下。

定义多个注解,并且存在嵌套:

@SecondLevel
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.TYPE, ElementType.METHOD})
@interface FirstLevel {
    String value();
    String info();
}

@ThirdLevel
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.TYPE, ElementType.METHOD})
@interface SecondLevel {
}

@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.TYPE, ElementType.METHOD})
@interface ThirdLevel {
}

使用注解并读取:

@FirstLevel(value = "first", info = "写在类上面")
public class AnnotationUtilsTest {

    // ------ 读取注解 ------

    public static void main(String[] args) throws NoSuchMethodException {
        // 获取AnnotationUtilsTest的Class,利用AnnotationUtils获取类上的注解
        Class<?> clazz = AnnotationUtilsTest.class;
        FirstLevel firstLevel = AnnotationUtils.findAnnotation(clazz, FirstLevel.class);  	// yes
        SecondLevel secondLevel = AnnotationUtils.findAnnotation(clazz, SecondLevel.class); // yes
        ThirdLevel thirdLevel = AnnotationUtils.findAnnotation(clazz, ThirdLevel.class);	// yes

        // 获取AnnotationUtilsTest的Class,利用AnnotationUtils获取annotationMethod上的注解
        Method annotationMethod = clazz.getMethod("annotationMethod");
        FirstLevel firstLevel1 = AnnotationUtils.getAnnotation(annotationMethod, FirstLevel.class);		// yes
        SecondLevel secondLevel1 = AnnotationUtils.getAnnotation(annotationMethod, SecondLevel.class);	// yes
        ThirdLevel thirdLevel1 = AnnotationUtils.getAnnotation(annotationMethod, ThirdLevel.class);		// yes

		// 获取AnnotationUtilsTest的Class,利用AnnotationUtils获取noAnnotationMethod上的注解
        Method noAnnotationMethod = clazz.getMethod("noAnnotationMethod");
        FirstLevel firstLevel2 = AnnotationUtils.getAnnotation(noAnnotationMethod, FirstLevel.class); 	// null
        SecondLevel secondLevel2 = AnnotationUtils.getAnnotation(noAnnotationMethod, SecondLevel.class);// null
        ThirdLevel thirdLevel2 = AnnotationUtils.getAnnotation(noAnnotationMethod, ThirdLevel.class);	// null

        Object value = AnnotationUtils.getValue(firstLevel, "value");  // first
        Object info = AnnotationUtils.getValue(firstLevel, "info");	   // 写在类上面

        System.out.println("over");
    }

    // ------ 使用注解 ------

    @FirstLevel(value = "first", info = "写在方法上面")
    public void annotationMethod() {

    }

    public void noAnnotationMethod() {

    }

}

学完AnnotationUtils,我们正式开始~

SQL

CREATE TABLE `t_user` (
  `id` bigint(20) unsigned NOT NULL AUTO_INCREMENT,
  `name` varchar(255) NOT NULL,
  `password` varchar(255) NOT NULL,
  `create_time` datetime DEFAULT CURRENT_TIMESTAMP,
  `update_time` datetime DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
  `deleted` tinyint(1) DEFAULT '0',
  PRIMARY KEY (`id`)
) ENGINE=InnoDB AUTO_INCREMENT=3 DEFAULT CHARSET=utf8mb4;

注解:@LoginRequired

@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.TYPE, ElementType.METHOD})
public @interface LoginRequired {
}

通用组件

/**
 * 通用错误枚举
 *
 * @author mx
 */
@Getter
public enum ExceptionCodeEnum {

    /**
     * 通用结果
     */
    ERROR(-1, "网络错误"),
    SUCCESS(200, "成功"),
    NEED_LOGIN(10001, "需要登录"),
    PERMISSION_DENY(10002, "权限不足");

    private final Integer code;
    private final String desc;

    ExceptionCodeEnum(Integer code, String desc) {
        this.code = code;
        this.desc = desc;
    }

    private static final Map<Integer, ExceptionCodeEnum> cache = new HashMap<>();

    static {
        for (ExceptionCodeEnum exceptionCodeEnum : ExceptionCodeEnum.values()) {
            cache.put(exceptionCodeEnum.code, exceptionCodeEnum);
        }
    }

    public static String getDesc(Integer code) {
        return Optional.ofNullable(cache.get(code))
                .map(ExceptionCodeEnum::getDesc)
                .orElseThrow(() -> new IllegalArgumentException("invalid exception code!"));
    }

}
/**
 * 业务异常
 * biz是business的缩写
 * @see ExceptionCodeEnum
 */
@Getter
public class BizException extends RuntimeException {

    private ExceptionCodeEnum error;

    /**
     * 构造器,有时我们需要将第三方异常转为自定义异常抛出,同时又不想丢失原来的异常信息,此时可以传入cause
     *
     * @param error
     * @param cause
     */
    public BizException(ExceptionCodeEnum error, Throwable cause) {
        super(cause);
        this.error = error;
    }

    /**
     * 构造器,只传入通用错误枚举
     *
     * @param error
     */
    public BizException(ExceptionCodeEnum error) {
        this.error = error;
    }
}
/**
 * 全局异常处理
 */
@Slf4j
@RestControllerAdvice
public class GlobalExceptionHandler {

    /**
     * 业务异常
     *
     * @param
     * @return
     */
    @ExceptionHandler(BizException.class)
    public Result<ExceptionCodeEnum> handleBizException(BizException bizException) {
        log.warn("业务异常:{}", bizException.getError().getDesc(), bizException);
        return Result.error(bizException.getError());
    }

    /**
     * 运行时异常
     *
     * @param e
     * @return
     */
    @ExceptionHandler(RuntimeException.class)
    public Result<ExceptionCodeEnum> handleRunTimeException(RuntimeException e) {
        log.warn("运行时异常: {}", e.getMessage(), e);
        return Result.error(ExceptionCodeEnum.ERROR);
    }

}
public abstract class WebConstant {

    public static final String CURRENT_USER_IN_SESSION = "current_user_in_session";
    public static final String USER_INFO = "user_info";
}
public final class ThreadLocalUtil {

    private ThreadLocalUtil() {
    }

    /**
     * ThreadLocal的静态方法withInitial()会返回一个SuppliedThreadLocal对象
     * 而SuppliedThreadLocal<T> extends ThreadLocal<T>
     * 我们存进去的Map会作为的返回值:
     * protected T initialValue() {
     *    return supplier.get();
     * }
     * 
     * 所以也相当于重写了initialValue()
     * 
     */
    private final static ThreadLocal<Map<String, Object>> THREAD_CONTEXT = ThreadLocal.withInitial(
            () -> new HashMap<>(8)
    );

    /**
     * 根据key获取value
     * 比如key为USER_INFO,则返回"{'name':'bravo', 'age':18}"
     * {
     * ...THREAD_CONTEXT: {
     * ........."USER_INFO":"{'name':'bravo', 'age':18}",
     * ........."SCORE":"{'Math':99, 'English': 97}"
     * ...}
     * }
     *
     * @param key
     * @return
     */
    public static Object get(String key) {
        // getContextMap()表示要先获取THREAD_CONTEXT的value,也就是Map<String, Object>。然后再从Map<String, Object>中根据key获取
        return THREAD_CONTEXT.get().get(key);
    }

    /**
     * put操作,原理同上
     *
     * @param key
     * @param value
     */
    public static void put(String key, Object value) {
        THREAD_CONTEXT.get().put(key, value);
    }

    /**
     * 清除map里的某个值
     * 比如把
     * {
     * ...THREAD_CONTEXT: {
     * ........."USER_INFO":"{'name':'bravo', 'age':18}",
     * ........."SCORE":"{'Math':99, 'English': 97}"
     * ...}
     * }
     * 变成
     * {
     * ...THREAD_CONTEXT: {
     * ........."SCORE":"{'Math':99, 'English': 97}"
     * ...}
     * }
     *
     * @param key
     * @return
     */
    public static Object remove(String key) {
        return THREAD_CONTEXT.get().remove(key);
    }

    /**
     * 清除整个Map<String, Object>
     * 比如把
     * {
     * ...THREAD_CONTEXT: {
     * ........."USER_INFO":"{'name':'bravo', 'age':18}",
     * ........."SCORE":"{'Math':99, 'English': 97}"
     * ...}
     * }
     * 变成
     * {
     * ...THREAD_CONTEXT: {}
     * }
     */
    public static void clear() {
        THREAD_CONTEXT.get().clear();
    }

    /**
     * 从ThreadLocalMap中清除当前ThreadLocal存储的内容
     * 比如把
     * {
     * ...THREAD_CONTEXT: {
     * ........."USER_INFO":"{'name':'bravo', 'age':18}",
     * ........."SCORE":"{'Math':99, 'English': 97}"
     * ...}
     * }
     * 变成
     * {
     * }
     */
    public static void clearAll() {
        THREAD_CONTEXT.remove();
    }

}

登录拦截(关键代码)

@Configuration
public class MvcConfig implements WebMvcConfigurer {

    /**
     * 把拦截器通过registry注册到Spring容器
     * 一般有两种方式:
     * 1.registry.addInterceptor(new LoginInterceptor())
     * 2.给LoginInterceptor加@Component,通过@Autowired注入,然后registry.addInterceptor(loginInterceptor)
     * 【推荐使用@Autowired注入】:如果LoginInterceptor内部需要注入其他组件比如RedisTemplate,那么直接new LoginInterceptor会注入失败
     * @param registry
     */
    @Override
    public void addInterceptors(InterceptorRegistry registry) {
        registry.addInterceptor(new LoginInterceptor())
                .addPathPatterns("/**");
    }
}
public class LoginInterceptor extends HandlerInterceptorAdapter {

    @Override
    public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler) throws Exception {
        // 不拦截跨域请求相关
        if ("OPTIONS".equalsIgnoreCase(request.getMethod())) {
            return true;
        }

        // 如果方法上没有加@LoginRequired,无需登录,直接放行
        if (isLoginFree(handler)) {
            return true;
        }

        // 登录成功,把用户信息存入ThreadLocal
        User user = handleLogin(request, response);
        ThreadLocalUtil.put(WebConstant.USER_INFO, user);

        // 放行到Controller
        return super.preHandle(request, response, handler);
    }

    @Override
    public void afterCompletion(HttpServletRequest request, HttpServletResponse response, Object handler, Exception ex) throws Exception {
        // 及时移除,避免ThreadLocal内存泄漏
        ThreadLocalUtil.remove(WebConstant.USER_INFO);
        super.afterCompletion(request, response, handler, ex);
    }

    /**
     * 接口是否免登录
     *
     * @param handler
     * @return
     */
    private boolean isLoginFree(Object handler) {
        // 判断是否支持免登录
        if (handler instanceof HandlerMethod) {
            HandlerMethod handlerMethod = (HandlerMethod) handler;
            Method method = handlerMethod.getMethod();
            // AnnotationUtils是Spring提供的注解工具类,还有很多其他便利的方法
            LoginRequired loginRequiredAnnotation = AnnotationUtils.getAnnotation(method, LoginRequired.class);
            // 没有加@LoginRequired,不需要登录
            return loginRequiredAnnotation == null;
        }

        return true;
    }

    /**
     * 登录校验
     *
     * @param request
     * @param response
     * @return
     */
    private User handleLogin(HttpServletRequest request, HttpServletResponse response) {
        HttpSession session = request.getSession();
        User currentUser = (User) session.getAttribute(WebConstant.CURRENT_USER_IN_SESSION);
        if (currentUser == null) {
            // 抛异常,请先登录(还有一种方式,就是利用response直接write返回JSON,但不推荐)
            throw new BizException(ExceptionCodeEnum.NEED_LOGIN);
        }
        return currentUser;
    }
}

测试

省略了Service层

@RestController
@RequestMapping("/user")
public class UserController {

    @Autowired
    private UserMapper userMapper;
    @Autowired
    private HttpSession session;

    @PostMapping("/register")
    public Result<User> register(@RequestBody User userInfo) {
        int rows = userMapper.insert(userInfo);
        if (rows > 0) {
            return Result.success(userInfo);
        }

        return Result.error("插入失败");
    }

    @PostMapping("/login")
    public Result<User> login(@RequestBody User loginInfo) {
        // 用的是MyBatis-Plus
        LambdaQueryWrapper<User> lambdaQuery = Wrappers.lambdaQuery();
        lambdaQuery.eq(User::getName, loginInfo.getName());
        lambdaQuery.eq(User::getPassword, loginInfo.getPassword());

        User user = userMapper.selectOne(lambdaQuery);
        if (user == null) {
            return Result.error("用户名或密码错误");
        }

        session.setAttribute(WebConstant.CURRENT_USER_IN_SESSION, user);
        return Result.success(user);
    }

    @LoginRequired
    @GetMapping("/needLogin")
    public Result<String> needLogin() {
        return Result.success("if you see this, you are logged in.");
    }

    @GetMapping("/needNotLogin")
    public Result<String> needNotLogin() {
        return Result.success("if you see this, you are logged in.");
    }
}

本文使用的是session的方式展示登录,JWT对于这个场景也是一样的。

另外,虽然@LoginRequired标明了可以使用在类或方法上,但是上面的代码只实现了对方法的判断,你能帮忙完善一下吗?

可以把你完善后的代码贴在评论区。 

作者简介:大家好,我是smart哥,前中兴通讯、美团架构师,现某互联网公司CTO

进群,大家一起学习,一起进步,一起对抗互联网寒冬
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值