【SpringBoot】SpringBoot 中使用自定义注解来实现接口参数校验

在后台接口做参数校验,一般有两种方案:

  • hibernate-validator
  • AOP + 自定义注解 实现方法级的参数校验

开发环境:

JDK:1.8
SpringBoot:2.1.1.RELEASE
IDEA:2019.1.1

1. hibernate-validator

hibernate-validator 是 Hibernate 项目中的一个数据校验框架,是 Bean Validation 的参考实现。

使用 hibernate-validator 能够将数据校验从业务代码中脱离出来,增加代码可读性。同时,也让数据校验变得更加方便、简单。

添加 hibernate-validator 依赖:

<dependency>
    <groupId>org.hibernate</groupId>
    <artifactId>hibernate-validator</artifactId>
    <version>4.3.1.Final</version>
</dependency>

【注意】:在 SpringBoot 2.1.1.RELEASE 中 不需要引入 Hibernate Validator , 因为 在引入的 spring-boot-starter-web(springbootweb启动器)依赖的时候中,内部已经依赖了 hibernate-validator 依赖包。

待校验的Vo:

public class ValidatorVo {

    @NotEmpty(message = "用户名不能为空")
    private String sName;

    @NotEmpty(message = "手机号不能为空")
    @Pattern(regexp = "^1[3|4|5|7|8][0-9]\\d{8}$", message = "手机号格式不正确")
    private String sPhone;

    @NotNull(message = "age 不能为空")
    private Integer age;

    //getter/setter
}

说明:

  • @NotEmpty:使用 hibernate-validator 校验规则

定义一个控制器:

@RestController
@RequestMapping("/param")
public class ParamController {

    @PostMapping("/validator")
    public String validator(@RequestBody @Valid ValidatorVo validatorVo) {
        return "TEST VALIDATOR";
    }

}

说明:

  • @Valid:对其参数启用验证

定义一个全局异常类:

@Slf4j
@RestControllerAdvice
public class GlobalExceptionHandler {

    private static final String logExceptionFormat = "Capture Exception By GlobalExceptionHandler: Code: %s Detail: %s";

    @ExceptionHandler(value = {BindException.class, MethodArgumentNotValidException.class})
    public Object validationExceptionHandler(Exception ex) {
        return validationResult(1001, ex);
    }

    private <T extends Throwable> ResultVo validationResult(Integer code, T exception) {
        // 做日志记录处理
        doLog(code, exception);
        
        BindingResult bindResult = null;
        if (exception instanceof BindException) {
            bindResult = ((BindException) exception).getBindingResult();
        } else if (exception instanceof MethodArgumentNotValidException) {
            bindResult = ((MethodArgumentNotValidException) exception).getBindingResult();
        }
        String msg = null;
        if (null != bindResult && bindResult.hasErrors()) {
            msg = bindResult.getAllErrors().get(0).getDefaultMessage();
            if (msg.contains("NumberFormatException")) {
                msg = "参数类型错误!";
            }
        } else {
            msg = "系统繁忙,请稍后重试...";
        }
        return ResultVoUtil.error(code, msg);
    }

    private <T extends Throwable> void doLog(Integer status, T exception) {
        exception.printStackTrace();
        log.error(String.format(logExceptionFormat, status, exception.getMessage()));
    }

}

ResultVoUtil:返回前端对象信息工具类

public class ResultVoUtil {

    public static ResultVo success() {
        return success(null);
    }
    public static ResultVo success(Object object) {
        ResultVo result = new ResultVo();
        result.setCode(ResultCodeEnum.SUCCESS.getCode());
        result.setMsg("成功");
        result.setData(object);
        return result;
    }
    public static ResultVo success(Integer code, Object object) {
        return success(code, null, object);
    }
    public static ResultVo success(Integer code, String msg, Object object) {
        ResultVo result = new ResultVo();

        result.setCode(code);
        result.setMsg(msg);
        result.setData(object);
        return result;
    }

    public static ResultVo error(String msg) {
        ResultVo result = new ResultVo();
        result.setCode(ResultCodeEnum.ERROR.getCode());
        result.setMsg(msg);
        return result;
    }
    public static ResultVo error(Integer code, String msg) {
        ResultVo result = new ResultVo();
        result.setCode(code);
        result.setMsg(msg);
        return result;
    }

}
@Data
public class ResultVo<T> {

    // 错误码
    private Integer code;

    // 提示信息
    private String msg;

    // 返回的数据
    private T data;

    public boolean checkSuccess() {
        return ResultCodeEnum.SUCCESS.getCode().equals(this.code);
    }

}

好了,这就是使用了 hibernate-validator 校验。

2. AOP + 自定义注解 实现方法级的参数校验

hibernate-validator 是在实体类上添加注解;但对于不同的方法,所应用的校验规则也是不一样的,这样子可能就会需要创建多个实体类或者组,甚至于一些接口根本就没实体类参数;所以实际应用过程中还是有一定的困难;

所以,这里简单地实现了一套基于 自定义注解 + AOP 的方式实现接口参数校验框架。在方法体上使用@CheckParam 或者 @CheckParams 注解标注需要校验的参数

步骤一:自定义注解

自定义注解,给需要校验的方法进行注解。分为:单参数校验 @MyCheckParam 和多参数校验 @MyCheckParams

MyCheckParam:

@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
public @interface MyCheckParam {

    // 字段校验规则
    MyCheckParamEnum value() default MyCheckParamEnum.NOT_NULL;
    // 参数名称。用"."表示层级。最多支持2级。如:userVo.userName
    String argName();
    // 表达式。多个值用","分割。跟argName有关。
    String express() default "";
    // 自定义提示信息
    String msg() default "";

}

【注意】:MyCheckParamEnum 是一个枚举类哈。

MyCheckParams:

@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
public @interface MyCheckParams {
    
    // 多个MyCheckParam,由上至下校验
    MyCheckParam[] value();
    
}

可以校验多个参数。

步骤二:自定义枚举校验类

@Getter
public enum MyCheckParamEnum {

    NULL("参数必须为 null", MyCheckParamUtil::isNull),
    NOT_NULL("参数必须不为 null", MyCheckParamUtil::isNotNull),
    EMPTY("参数的必须为空", MyCheckParamUtil::isEmpty),
    NOT_EMPTY("参数必须非空", MyCheckParamUtil::isNotEmpty),
    LENGTH("参数长度必须在指定范围内", MyCheckParamUtil::inLength),
    GE("参数必须大于等于指定值", MyCheckParamUtil::isGreaterThanEqual),
    LE("参数必须小于等于指定值", MyCheckParamUtil::isLessThanEqual),
    ;

    private String msg;

    // 接收字段值(Object)和 表达式(String),返回是否符合规则(Boolean)
    private BiFunction<Object, String, Boolean> fun;

    MyCheckParamEnum(String msg, BiFunction<Object, String, Boolean> fun) {
        this.msg = msg;
        this.fun = fun;
    }

}

说明:

  1. 此枚举类是用于参数校验的。它有两个属性:msgfunmsg:参数校验不通过时的默认报错信息;fun:进行参数校验时,需要执行的方法,位于 MyCheckParamUtil 类中。一个枚举实例对应一个 fun

步骤三:自定义枚举校验工具类

校验时,会调用此类中的方法。根据校验规则不同,调用的方法不同。如:

校验规则是 NULL,则它调用的方法是:MyCheckParamUtil::isNull

public class MyCheckParamUtil {

    // 判断对象是否不为 null
    public static Boolean isNotNull(Object value, String express) {
        if (null == value) {
            return Boolean.FALSE;
        }
        return Boolean.TRUE;
    }

    public static Boolean isNull(Object value, String express) {
        return !isNotNull(value, express);
    }

    // 判断value !=null && length、size > 0
    public static Boolean isNotEmpty(Object value, String express) {
        if(isNull(value, express)) {
            return Boolean.FALSE;
        }
        if(value instanceof String && "".equals(((String) value).trim())) {
            return Boolean.FALSE;
        }
        if(value instanceof Collection && CollectionUtils.isEmpty((Collection) value)) {
            return Boolean.FALSE;
        }
        if (value instanceof Map && ((Map) value).isEmpty()) {
            return Boolean.FALSE;
        }
        return Boolean.TRUE;
    }

    public static Boolean isEmpty(Object value, String express) {
        return !isNotEmpty(value, express);
    }

    // 判断某个值的长度是否在某个范围
    public static Boolean inLength(Object value, String express) {
        if(isNull(value, express)) {
            return Boolean.FALSE;
        }
        if(null == express || "".equals(express)) {
            return Boolean.FALSE;
        }
        String[] split = express.split(",");
        if (null == split || split.length != 2) {
            return Boolean.FALSE;
        }
        if (value instanceof String) {
            Integer begin = Integer.valueOf(split[0].trim());
            Integer end = Integer.valueOf(split[1].trim());
            Integer length = ((String) value).length();
            return begin <= length && length <= end;
        }
        return Boolean.FALSE;
    }

    // 判断是否大于等于某个值
    public static Boolean isGreaterThanEqual(Object value, String express) {
        if (value == null) {
            return Boolean.FALSE;
        }
        if(value instanceof Integer) {
            return ((Integer) value) >= Integer.valueOf(express);
        }
        if(value instanceof Long) {
            return ((Long) value) >= Long.valueOf(express);
        }
        if(value instanceof Short) {
            return ((Short) value) >= Short.valueOf(express);
        }
        if(value instanceof Float) {
            return ((Float) value) >= Float.valueOf(express);
        }
        if(value instanceof Double) {
            return ((Double) value) >= Double.valueOf(express);
        }
        if(value instanceof String) {
            return ((String) value).length() >= Integer.valueOf(express);
        }
        if(value instanceof Collection) {
            return  ((Collection) value).size() >= Integer.valueOf(express);
        }
        return Boolean.FALSE;
    }

    // 判断是否大于等于某个值
    public static Boolean isLessThanEqual(Object value, String express) {
        if (value == null) {
            return Boolean.FALSE;
        }
        if(value instanceof Integer) {
            return ((Integer) value) <= Integer.valueOf(express);
        }
        if(value instanceof Long) {
            return ((Long) value) <= Long.valueOf(express);
        }
        if(value instanceof Short) {
            return ((Short) value) <= Short.valueOf(express);
        }
        if(value instanceof Float) {
            return ((Float) value) <= Float.valueOf(express);
        }
        if(value instanceof Double) {
            return ((Double) value) <= Double.valueOf(express);
        }
        if(value instanceof String) {
            return ((String) value).length() <= Integer.valueOf(express);
        }
        if(value instanceof Collection) {
            return  ((Collection) value).size() <= Integer.valueOf(express);
        }
        return Boolean.FALSE;
    }

}

步骤四:自定义 AOP

在 AOP 中,对注解进行解析、处理。主要逻辑在方法 doCheckParam()

MyCheckParamAspect:

@Slf4j
@Aspect
@Component
public class MyCheckParamAspect {

    // 单个参数校验切入点
    @Pointcut("@annotation(com.tinady.annotation.MyCheckParam)")
    public void doMyCheckParam() {}

    // 多个参数校验切入点
    @Pointcut("@annotation(com.tinady.annotation.MyCheckParams)")
    public void doMyCheckParams() {}

    // 单参数校验
    @Around("doMyCheckParam()")
    public Object doMyCheckParamAround(ProceedingJoinPoint joinPoint) throws Throwable {
        String msg = doCheckParam(joinPoint, false);
        // 参数校验未通过,则直接抛出自定义异常
        if (null != msg) {
            throw new MyCheckParamException(msg);
        }
        // 参数校验通过,则继续执行原来方法
        Object proceed = joinPoint.proceed();
        return proceed;
    }

    // 多参数校验
    @Around("doMyCheckParams()")
    public Object doMyCheckParamsAround(ProceedingJoinPoint joinPoint) throws Throwable {
        String msg = doCheckParam(joinPoint, true);
        // 参数校验未通过,则直接抛出自定义异常
        if (null != msg) {
            throw new MyCheckParamException(msg);
        }
        // 参数校验通过,则继续执行原来方法
        Object proceed = joinPoint.proceed();
        return proceed;
    }

    /**
    *
    * 功能描述: 参数校验
    *
    * @param joinPoint 切点
    * @param isMulti 是否是多参数校验
    * @return java.lang.String 错误信息
    * @date 2022-01-01 14:25
    */
    private String doCheckParam(ProceedingJoinPoint joinPoint, boolean isMulti) {
        Method method = this.getMethod(joinPoint);
        String[] paramNames = this.getParamNames(joinPoint);
        // 获取前端传递后台接口的所有入参对应的入参值
        Object[] arguments = joinPoint.getArgs();

        Boolean isValid = Boolean.TRUE;
        String msg = null;
        
        // 单参数校验
        if (!isMulti) {
            // AOP监听带注解的方法,所以不用判断注解是否为空
            MyCheckParam myCheckParam = method.getAnnotation(MyCheckParam.class);
            String argName = myCheckParam.argName();
            Object value = getParamValue(arguments, paramNames, argName);
            
            // 通过执行fun来校验value
            isValid = myCheckParam.value().getFun().apply(value, myCheckParam.express());
            
            msg = myCheckParam.msg();
            if (null == msg || "".equals(msg)) {
                msg = argName + ": " + myCheckParam.value().getMsg() + " " + myCheckParam.express();
            }
        } else {
            MyCheckParams myCheckParams = method.getAnnotation(MyCheckParams.class);
            MyCheckParam[] checkParams = myCheckParams.value();
            for (MyCheckParam checkParam : checkParams) {
                String argName = checkParam.argName();
                Object value = this.getParamValue(arguments, paramNames, argName);
                isValid = checkParam.value().getFun().apply(value, checkParam.express());
                // 只要有一个参数判断不通过,立即返回
                if (!isValid) {
                    msg = checkParam.msg();
                    if(null == msg || "".equals(msg)) {
                        msg = argName + ": " + checkParam.value().getMsg() + " " + checkParam.express();
                    }
                    break;
                }
            }
        }
        
        if (!isValid) {
            log.error("校验未通过");
            return msg;
        }
        log.info("校验通过");
        return null;
    }
}

说明:

  1. isValid = myCheckParam.value().getFun().apply(value, myCheckParam.express());:获取自定义注解中的 value,它是个枚举类。枚举类中有两个属性:msgfun。然后获取枚举中的 fun 方法,传入入参:value 和 自定义注解中的 express,然后执行此方法,获取其返回值

MyCheckParamAspect#getMethod():获取当前正在执行的方法

private Method getMethod(JoinPoint joinPoint) {
    MethodSignature methodSignature = (MethodSignature)joinPoint.getSignature();
    Method method = methodSignature.getMethod();
    if (method.getDeclaringClass().isInterface()) {
        try {
            method = joinPoint.getTarget().getClass().getDeclaredMethod(joinPoint.getSignature().getName(), method.getParameterTypes());
        } catch (NoSuchMethodException e) {
            e.printStackTrace();
        }
    }
    return method;
}

MyCheckParamAspect#getParamNames():获取当前正在执行的方法的入参

private String[] getParamNames(JoinPoint joinPoint) {
    MethodSignature methodSignature = (MethodSignature)joinPoint.getSignature();
    String[] parameterNames = methodSignature.getParameterNames();
    return parameterNames;
}

MyCheckParamAspect#getParamValue():获取参数对应的值

private Object getParamValue(Object[] arguments, String[] paramNames, String argName) {
    Object value = null;
    String name = argName;
    // 从对象中取值
    if (argName.contains(".")) {
        name = argName.split("\\.")[0];
    }
    int index = 0;
    for (String s : paramNames) {
        if (s.equals(name)) {
            value = arguments[index];
            break;
        }
        index++;
    }
    if (argName.contains(".")) {
        argName = argName.split("\\.")[1];
        JSONObject jsonObject = (JSONObject)JSONObject.toJSON(value);
        value = jsonObject.get(argName);
    }
    return value;
}

步骤五:自定义异常类

当参数校验不通过时,立即抛出异常。中断执行。

@Data
public class MyCheckParamException extends RuntimeException {

    // 错误码
    private Integer code;
    
    // 错误消息
    private String msg;

    public MyCheckParamException() {
        this(303, "参数错误");
    }

    public MyCheckParamException(String msg) {
        this(300, msg);
    }

    public MyCheckParamException(Integer code, String msg) {
        super(msg);
        this.code = code;
        this.msg = msg;
    }

}

步骤六:自定义统一异常处理类

抛出异常后,经过统一异常处理,然后返回给前台

@RestControllerAdvice
public class GlobalExceptionHandler {

	@ExceptionHandler(MyCheckParamException.class)
    public Object MyCheckParamException(MyCheckParamException myCheckParamException) {
        return ResultVoUtil.error(myCheckParamException.getCode(), myCheckParamException.getMsg());
    }
    
}

【注意】:ResultVoUtil 类上文有啊。这里就不再贴出了。

步骤七:使用自定义注解

Controller 中,哪个方法需要校验,就在哪个方法上加自定义注解 @MyCheckParam@MyCheckParams

@RestController
@RequestMapping("/param")
public class ParamController {

	@PostMapping("/oneChechNotNull")
    @MyCheckParam(value = MyCheckParamEnum.NOT_EMPTY, argName = "userName", msg = "草,这个是必填参数!")
    public String oneChechNotNull(String userName) {
        return "oneChechNotNull";
    }

    @PostMapping("/oneChechObjectAttrNotNull")
    @MyCheckParam(value = MyCheckParamEnum.NOT_EMPTY, argName = "userVo.name")
    public String oneChechObjectAttrNotNull(@RequestBody UserVo userVo) {
        return "oneChechObjectAttrNotNull";
    }

    @PostMapping("/oneChechLengthIn")
    @MyCheckParam(value = MyCheckParamEnum.LENGTH, argName = "password", express = "6,18", msg = "密码length必须在6-18位之间!")
    public String oneChechLengthIn(String password) {
        return "oneChechLengthIn";
    }

    @PostMapping("/multiCheckLengthIn")
    @MyCheckParams({
            @MyCheckParam(value = MyCheckParamEnum.GE, argName = "password", express = "6"),
            @MyCheckParam(value = MyCheckParamEnum.LE, argName = "password", express = "18"),
    })
    public String multiCheckLengthIn(String password) {
        return "multiCheckLengthIn";
    }
}

使用 AOP 实现接口参数校验就到这了。可以自己运行代码看看。

  • 1
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值