Java 反射入门学习笔记

一、导言
利用反射,可以在运行时通过类名访问字段、调用方法和构造方法、获取类的继承关系,还可以实现动态代理(只编写接口,不编写实现类)。
二、简单示例
下面的例子演示如何利用反射给对象的私有字段赋值。

package com.itranswarp.learnjava;

import java.lang.reflect.Field;

public class Main {

	/**
	 * Creates a {@code Person} and fills its {@code name} and {@code age} fields through
	 * reflection instead of setters, then prints them back through the getters.
	 */
	public static void main(String[] args) throws ClassNotFoundException, NoSuchFieldException, IllegalAccessException {
		Person person = new Person();
		Class<?> personClass = person.getClass();

		// getDeclaredField also finds private fields; setAccessible(true) is required to write them
		Field nameField = personClass.getDeclaredField("name");
		nameField.setAccessible(true);
		nameField.set(person, "Xiao Ming");

		Field ageField = personClass.getDeclaredField("age");
		ageField.setAccessible(true);
		ageField.set(person, 20); // autoboxed to Integer

		System.out.println(person.getName()); // "Xiao Ming"
		System.out.println(person.getAge()); // 20
	}
}

三、实际应用——在拦截器中利用反射校验请求参数

/**
 * Validates the request parameters of controller methods before the controller is invoked.
 *
 * <p>By default every special character is rejected; the per-value check is delegated to
 * {@link CheckParam#checkControllerParam}. Characters can be whitelisted per parameter with the
 * {@code @AllowStr} annotation (on a method parameter or on a field of a parameter object) or in
 * the {@code TsecurityParamcheckStrFilter} table. Controllers or single controller methods listed
 * in the {@code TsecurityParamcheckFilter} table are not checked at all.
 *
 * <p>Check results are maps with a boolean {@code "status"} entry and, on failure, a
 * {@code "reason"} entry.
 */
@Slf4j
public class ParamCheckInterceptor extends HandlerInterceptorAdapter {

    /** Message returned when a parameter contains a character that is not whitelisted. */
    private static final String ILLEGAL_CHAR_MSG = "输入参数含有非法字符!";
    /** Message returned when a JSON body value cannot be converted to its declared type. */
    private static final String FORMAT_ERROR_MSG = "输入参数格式错误";

    /** Types validated directly as a single value: String, the primitive wrappers and BigDecimal. */
    private static final Set<Class<?>> BASE_CLS = new HashSet<>();

    static {
        BASE_CLS.add(String.class);
        BASE_CLS.add(Integer.class);
        BASE_CLS.add(Byte.class);
        BASE_CLS.add(Boolean.class);
        BASE_CLS.add(Short.class);
        BASE_CLS.add(Character.class);
        BASE_CLS.add(Long.class);
        BASE_CLS.add(Float.class);
        BASE_CLS.add(Double.class);
        BASE_CLS.add(BigDecimal.class);
    }

    @Autowired
    private CheckParam checkParam;
    @Autowired
    private TsecurityParamcheckFilterMapper paramcheckFilterMapper;
    @Autowired
    private TsecurityParamcheckStrFilterMapper paramcheckStrFilterMapper;

    /**
     * Checks GET and POST requests that target a non-whitelisted controller method.
     *
     * @return {@code false} (after writing a failure response) if a parameter is rejected
     */
    @Override
    public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler) throws Exception {
        // Only controller methods are checked
        if (!(handler instanceof HandlerMethod)) {
            return true;
        }
        HandlerMethod method = (HandlerMethod) handler;
        // Whitelisted controllers / controller methods are not checked
        if (isFilterApi(method.getBeanType(), method.getMethod().getName())) {
            return true;
        }
        String requestMethod = request.getMethod();
        if ("GET".equals(requestMethod)) {
            return getMethodHandler(request, response, method);
        } else if ("POST".equals(requestMethod)) {
            return postMethodHandler(request, response, method);
        }
        // NOTE(review): other HTTP methods (PUT, DELETE, ...) are currently not validated
        return true;
    }

    /**
     * Validates the query parameters of a GET request.
     *
     * <p>Each request parameter is first matched to the controller parameter that receives it
     * (String/wrapper/List parameters by name, object parameters by field name) to collect its
     * whitelist; then every value is checked.
     *
     * @return {@code true} if all parameters are valid; otherwise a failure response has been
     *         written and {@code false} is returned
     */
    private boolean getMethodHandler(HttpServletRequest request, HttpServletResponse response, HandlerMethod method) throws Exception {
        // GET values always arrive as String[]
        List<ParamObj> paramObjs = new ArrayList<>();
        for (Map.Entry<String, String[]> entry : request.getParameterMap().entrySet()) {
            ParamObj paramObj = new ParamObj();
            paramObj.setParamName(entry.getKey());
            paramObj.setParamValue(entry.getValue());
            paramObjs.add(paramObj);
        }

        // Load the whitelist rows of this controller method once instead of querying per parameter
        List<TsecurityParamcheckStrFilter> paramFilterStrList =
                getParamFilterStrList(method.getBeanType(), method.getMethod().getName());

        for (MethodParameter parameter : method.getMethodParameters()) {
            Class<?> paramType = parameter.getParameterType();
            if (isServletParam(paramType)) {
                continue;
            }
            String paramIdx = String.valueOf(parameter.getParameterIndex());
            if (isBaseClass(paramType) || List.class.equals(paramType)) {
                // Simple parameter: all level-1 whitelist rows of this parameter index, whatever their name
                List<String> allowStrList = new ArrayList<>(getParamFilterStr(paramFilterStrList, paramIdx, "1", null));
                AllowStr allowStr = parameter.getParameterAnnotation(AllowStr.class);
                String declaredName = parameter.getParameter().getName();
                String annotatedName = declaredName;
                if (allowStr != null) {
                    allowStrList.addAll(Arrays.asList(allowStr.value()));
                    // Annotation members are never null; an empty name means "not set"
                    if (!ObjectUtil.isEmpty(allowStr.paramName())) {
                        annotatedName = allowStr.paramName();
                    }
                }
                final String requestName = annotatedName;
                paramObjs.stream()
                        .filter(p -> p.getParamName().equals(requestName) || p.getParamName().equals(declaredName))
                        .findFirst()
                        .ifPresent(p -> p.setAllowStr(allowStrList));
                continue;
            }
            // Object parameter: request parameters are bound to its fields by name
            for (Field field : getFields(parameter)) {
                List<String> allowStrList = new ArrayList<>(getParamFilterStr(paramFilterStrList, paramIdx, "2", field.getName()));
                AllowStr allowStr = field.getAnnotation(AllowStr.class);
                if (allowStr != null) {
                    allowStrList.addAll(Arrays.asList(allowStr.value()));
                }
                paramObjs.stream()
                        .filter(p -> p.getParamName().equals(field.getName()))
                        .findFirst()
                        .ifPresent(p -> p.setAllowStr(allowStrList));
            }
        }

        for (ParamObj paramObj : paramObjs) {
            Map<String, Object> resultMap = checkList((String[]) paramObj.getParamValue(), paramObj.getAllowStr());
            if (!isPassed(resultMap)) {
                log.info("接口{}参数{}校验出错,请勿输入非法字符[{}]", getControllerApiName(method), paramObj.getParamName(), resultMap.get("reason"));
                writeFailResponse(response, ILLEGAL_CHAR_MSG);
                return false;
            }
        }
        return true;
    }

    /**
     * Validates the parameters of a POST request (query string, form data and JSON object body).
     *
     * @return {@code true} if all parameters are valid; otherwise a failure response has been
     *         written and {@code false} is returned
     */
    private boolean postMethodHandler(HttpServletRequest request, HttpServletResponse response, HandlerMethod method) throws Exception {
        MethodParameter[] parameters = method.getMethodParameters();
        Map<String, ParamObj> paramObjMap = collectPostParams(request);

        // A single Map parameter carries no type information: check every value with the default rules
        if (parameters.length == 1 && Map.class.equals(parameters[0].getParameterType())) {
            return checkRawMapParams(response, method, paramObjMap);
        }

        List<TsecurityParamcheckStrFilter> paramFilterStrList =
                getParamFilterStrList(method.getBeanType(), method.getMethod().getName());
        // Attach declared type and whitelist to every request parameter the controller accepts ...
        for (MethodParameter parameter : parameters) {
            Class<?> paramType = parameter.getParameterType();
            if (isServletParam(paramType)) {
                continue;
            }
            if (isBaseClass(paramType)) {
                bindSimplePostParam(parameter, paramObjMap, paramFilterStrList);
            } else {
                bindObjectPostParam(parameter, paramObjMap, paramFilterStrList);
            }
        }
        // ... then validate each value exactly once, after all parameters have been bound
        for (ParamObj param : paramObjMap.values()) {
            if (!checkPostParam(response, method, param)) {
                return false;
            }
        }
        return true;
    }

    /**
     * Merges query/form parameters and the top-level keys of a JSON object body into
     * request-parameter-name -> {@link ParamObj}. Body keys override query keys of the same name.
     */
    private Map<String, ParamObj> collectPostParams(HttpServletRequest request) {
        Map<String, Object> rawParams = new HashMap<>();
        // Query string or application/x-www-form-urlencoded parameters
        Map<String, String[]> queryParams = request.getParameterMap();
        if (queryParams != null) {
            rawParams.putAll(queryParams);
        }
        String body = readAsChars(request);
        try {
            Map<String, Object> bodyParams = JSON.parseObject(body, Map.class);
            if (bodyParams != null) {
                rawParams.putAll(bodyParams);
            }
        } catch (Exception e) {
            // Best effort: a body that is not a JSON object (form data, JSON array, ...) is ignored
            log.debug("post请求体不是JSON对象,忽略", e);
        }
        Map<String, ParamObj> paramObjMap = new HashMap<>();
        for (Map.Entry<String, Object> entry : rawParams.entrySet()) {
            ParamObj param = new ParamObj();
            param.setParamName(entry.getKey());
            param.setParamValue(entry.getValue());
            paramObjMap.put(entry.getKey(), param);
        }
        return paramObjMap;
    }

    /**
     * Checks every non-empty value of a request bound to a single {@code Map} parameter,
     * without whitelist (nested/complex values are only checked via their string form).
     */
    private boolean checkRawMapParams(HttpServletResponse response, HandlerMethod method, Map<String, ParamObj> paramObjMap) throws Exception {
        for (ParamObj param : paramObjMap.values()) {
            Object value = param.getParamValue();
            if (ObjectUtil.isEmpty(value)) {
                continue;
            }
            Map<String, Object> result;
            if (value.getClass().isArray()) {
                result = checkList((String[]) value, null);
            } else {
                result = checkParam.checkControllerParam(String.valueOf(value), null);
            }
            if (!isPassed(result)) {
                log.info("接口{}参数{}校验出错,非法字符[{}]", getControllerApiName(method), param.getParamName(), result.get("reason"));
                writeFailResponse(response, ILLEGAL_CHAR_MSG);
                return false;
            }
        }
        return true;
    }

    /**
     * Records type and whitelist for a String/wrapper controller parameter. The request key is
     * the compiled parameter name; if no value exists under it (value not sent, or the class was
     * compiled without parameter names, giving arg0...), {@code @AllowStr(paramName=...)} is used.
     */
    private void bindSimplePostParam(MethodParameter parameter, Map<String, ParamObj> paramObjMap,
                                     List<TsecurityParamcheckStrFilter> paramFilterStrList) {
        AllowStr allowStr = parameter.getParameterAnnotation(AllowStr.class);
        String paramName = parameter.getParameter().getName();
        if (paramObjMap.get(paramName) == null) {
            if (allowStr == null || ObjectUtil.isEmpty(allowStr.paramName())) {
                // No usable name: this parameter cannot be matched and is not checked
                return;
            }
            paramName = allowStr.paramName();
        }
        ParamObj param = paramObjMap.get(paramName);
        if (param == null) {
            // Parameter not sent: nothing to check
            return;
        }
        List<String> allowStrList = new ArrayList<>();
        if (allowStr != null) {
            allowStrList.addAll(Arrays.asList(allowStr.value()));
        }
        allowStrList.addAll(getParamFilterStr(paramFilterStrList, String.valueOf(parameter.getParameterIndex()), "1", paramName));
        param.setParamType(parameter.getParameterType());
        param.setAllowStr(allowStrList);
    }

    /** Records type and whitelist for every request parameter that matches a field of an object parameter. */
    private void bindObjectPostParam(MethodParameter parameter, Map<String, ParamObj> paramObjMap,
                                     List<TsecurityParamcheckStrFilter> paramFilterStrList) {
        String paramIdx = String.valueOf(parameter.getParameterIndex());
        for (Field field : getFields(parameter)) {
            ParamObj param = paramObjMap.get(field.getName());
            if (param == null) {
                continue;
            }
            param.setParamType(getType(field));
            List<String> allowStrList = new ArrayList<>();
            AllowStr allowStr = field.getAnnotation(AllowStr.class);
            if (allowStr != null) {
                allowStrList.addAll(Arrays.asList(allowStr.value()));
            }
            allowStrList.addAll(getParamFilterStr(paramFilterStrList, paramIdx, "2", field.getName()));
            param.setAllowStr(allowStrList);
        }
    }

    /**
     * Validates one POST parameter.
     *
     * @return {@code false} after writing a failure response if the value contains an illegal
     *         character or cannot be converted to its declared type
     */
    private boolean checkPostParam(HttpServletResponse response, HandlerMethod method, ParamObj param) throws Exception {
        Object value = param.getParamValue();
        // No type: the controller does not accept this parameter (extra parameter). Skipping it
        // also keeps JSON.parseArray away from such values, which overflowed the stack.
        if (param.getParamType() == null || value == null || "".equals(value)) {
            return true;
        }
        Map<String, Object> resultMap;
        if (value.getClass().isArray()) {
            // Query-string / form values arrive as String[]
            resultMap = checkList((String[]) value, param.getAllowStr());
        } else {
            Object converted;
            try {
                converted = convertBodyValue(value, param.getParamType());
            } catch (Exception e) {
                log.info("接口{}参数{}格式错误", getControllerApiName(method), param.getParamName(), e);
                writeFailResponse(response, FORMAT_ERROR_MSG);
                return false;
            }
            if (converted instanceof Object[]) {
                resultMap = checkList((Object[]) converted, param.getAllowStr());
            } else {
                resultMap = checkObject(converted, param.getAllowStr());
            }
        }
        if (!isPassed(resultMap)) {
            log.info("接口{}参数{}校验出错,非法字符[{}]", getControllerApiName(method), param.getParamName(), resultMap.get("reason"));
            writeFailResponse(response, ILLEGAL_CHAR_MSG);
            return false;
        }
        return true;
    }

    /**
     * Converts a JSON body value to its declared type: JSON arrays become an element array,
     * String targets keep the raw text (converting to String through JSON fails), everything
     * else is parsed by fastjson. Throws if the value does not fit the type.
     */
    private static Object convertBodyValue(Object value, Class<?> type) {
        if (value instanceof JSONArray) {
            return JSON.parseArray(value.toString(), type).toArray();
        }
        String text = value.toString();
        if (String.class.equals(type)) {
            return text;
        }
        return JSON.parseObject(text, type);
    }

    /**
     * Writes the JSON failure body and closes the writer.
     * NOTE(review): the body is JSON but the content type is text/html — kept for client compatibility.
     */
    private void writeFailResponse(HttpServletResponse response, String message) throws IOException {
        response.setContentType("text/html;charset=utf-8");
        response.getWriter().write(JSON.toJSONString(HttpBaseResponseUtil.getFailResponse(message)));
        response.getWriter().close();
    }

    /** Instance fields (including inherited ones) of the parameter's declared type. */
    private List<Field> getFields(MethodParameter parameter) {
        return getInstanceFields(parameter.getParameterType());
    }

    /**
     * Collects the non-static, non-synthetic fields of {@code cls} and its superclasses (below
     * Object). Static fields are skipped: they are not request data, and enum constants refer to
     * themselves, which made the recursive check loop forever.
     */
    private static List<Field> getInstanceFields(Class<?> cls) {
        List<Field> fieldList = new ArrayList<>();
        for (Class<?> current = cls; current != null && current != Object.class; current = current.getSuperclass()) {
            for (Field field : current.getDeclaredFields()) {
                if (!java.lang.reflect.Modifier.isStatic(field.getModifiers()) && !field.isSynthetic()) {
                    fieldList.add(field);
                }
            }
        }
        return fieldList;
    }

    /**
     * Reads the request body as text (line breaks are dropped).
     * NOTE(review): this consumes the request stream; @RequestBody binding afterwards only works if
     * an upstream filter makes the body re-readable — confirm such a wrapper is installed.
     */
    private String readAsChars(HttpServletRequest request) {
        StringBuilder sb = new StringBuilder();
        try (BufferedReader br = request.getReader()) {
            String line;
            while ((line = br.readLine()) != null) {
                sb.append(line);
            }
        } catch (IOException e) {
            log.error("获取post请求body出错", e);
        }
        return sb.toString();
    }

    /** Checks every element; returns the first failure (status + reason) or a passing result. */
    private <T> Map<String, Object> checkList(T[] paramValue, List<String> allowStr) throws Exception {
        if (paramValue != null) {
            for (T item : paramValue) {
                Map<String, Object> result = checkObject(item, allowStr);
                if (!isPassed(result)) {
                    Map<String, Object> failure = new HashMap<>();
                    failure.put("status", false);
                    failure.put("reason", result.get("reason"));
                    return failure;
                }
            }
        }
        return passResult();
    }

    /**
     * Checks a value: base values go to {@link CheckParam#checkControllerParam}; lists, arrays
     * and other objects are checked recursively through their elements / instance fields.
     * {@code null} and {@link Date} values are accepted without checking.
     * NOTE(review): cyclic object graphs and JDK value types such as java.time classes are not
     * special-cased and would recurse into their internals.
     */
    private <T> Map<String, Object> checkObject(T object, List<String> allowStr) throws Exception {
        if (object == null || object instanceof Date) {
            return passResult();
        }
        Class<?> cls = object.getClass();
        if (isBaseClass(cls)) {
            return checkParam.checkControllerParam(object.toString(), allowStr);
        }
        // Fields are made accessible only here, after the base-class check: doing it for JDK
        // types such as String throws InaccessibleObjectException on JDK 16+.
        for (Field field : getInstanceFields(cls)) {
            field.setAccessible(true);
            Object fieldValue = field.get(object);
            if (fieldValue == null) {
                continue;
            }
            Map<String, Object> result;
            if (fieldValue instanceof List) {
                result = checkList(((List<?>) fieldValue).toArray(), allowStr);
            } else if (fieldValue instanceof Object[]) {
                result = checkList((Object[]) fieldValue, allowStr);
            } else {
                result = checkObject(fieldValue, allowStr);
            }
            // Stop at the first field that fails
            if (!isPassed(result)) {
                return result;
            }
        }
        return passResult();
    }

    /** A fresh passing check result. */
    private static Map<String, Object> passResult() {
        Map<String, Object> result = new HashMap<>();
        result.put("status", true);
        return result;
    }

    /** Whether a check result reports success; a missing status counts as failure. */
    private static boolean isPassed(Map<String, Object> result) {
        return result != null && Boolean.TRUE.equals(result.get("status"));
    }

    /** Whether values of this type are checked directly (String, primitive wrappers, BigDecimal). */
    private boolean isBaseClass(Class cls) {
        return BASE_CLS.contains(cls);
    }

    /** Servlet request/response parameters carry no user input and are skipped. */
    private static boolean isServletParam(Class<?> type) {
        return HttpServletRequest.class.equals(type) || HttpServletResponse.class.equals(type);
    }

    /**
     * Element type of a field: the type argument of {@code List<X>} (raw type if X is itself
     * generic), the component type of an array, otherwise the field type. Raw or wildcard lists
     * yield {@code List.class}; other generic containers such as Map are not unwrapped.
     */
    private Class<?> getType(Field field) {
        Class<?> fieldType = field.getType();
        if (List.class.equals(fieldType)) {
            Type genericType = field.getGenericType();
            if (genericType instanceof ParameterizedType) {
                Type elementType = ((ParameterizedType) genericType).getActualTypeArguments()[0];
                if (elementType instanceof Class) {
                    return (Class<?>) elementType;
                }
                if (elementType instanceof ParameterizedType
                        && ((ParameterizedType) elementType).getRawType() instanceof Class) {
                    return (Class<?>) ((ParameterizedType) elementType).getRawType();
                }
            }
        } else if (fieldType.isArray()) {
            return fieldType.getComponentType();
        }
        return fieldType;
    }

    /**
     * First mapped path of the controller method (@RequestMapping, @GetMapping or @PostMapping,
     * from {@code value} or {@code path}); falls back to the Java method name.
     */
    private String getControllerApiName(HandlerMethod handlerMethod) {
        Method method = handlerMethod.getMethod();
        String[] paths = null;
        if (method.isAnnotationPresent(RequestMapping.class)) {
            RequestMapping mapping = method.getAnnotation(RequestMapping.class);
            paths = mapping.value().length > 0 ? mapping.value() : mapping.path();
        } else if (method.isAnnotationPresent(GetMapping.class)) {
            GetMapping mapping = method.getAnnotation(GetMapping.class);
            paths = mapping.value().length > 0 ? mapping.value() : mapping.path();
        } else if (method.isAnnotationPresent(PostMapping.class)) {
            PostMapping mapping = method.getAnnotation(PostMapping.class);
            paths = mapping.value().length > 0 ? mapping.value() : mapping.path();
        }
        return paths != null && paths.length > 0 ? paths[0] : method.getName();
    }

    /**
     * Whether the controller method is whitelisted in TsecurityParamcheckFilter:
     * filter type "1" whitelists every request-mapped method of the bean, type "2" one method.
     * NOTE(review): runs a DB query (and Class.forName per row) on every request; consider caching.
     */
    private boolean isFilterApi(Class beanCls, String methodName) {
        List<TsecurityParamcheckFilter> filters =
                paramcheckFilterMapper.selectByExample(new TsecurityParamcheckFilterExample());
        for (TsecurityParamcheckFilter filter : filters) {
            if ("1".equals(filter.getFiltertype())) {
                try {
                    if (Class.forName(filter.getBean()).equals(beanCls) && isMappedMethod(beanCls, methodName)) {
                        return true;
                    }
                } catch (Exception e) {
                    log.error("获取Bean[{}]的所有接口出错", filter.getBean(), e);
                }
            } else if ("2".equals(filter.getFiltertype())) {
                try {
                    if (Class.forName(filter.getBean()).equals(beanCls) && methodName.equals(filter.getMethod())) {
                        return true;
                    }
                } catch (Exception e) {
                    log.error("获取Bean[{}]的class出错", filter.getBean(), e);
                }
            }
        }
        return false;
    }

    /** Whether {@code cls} declares a method of that name carrying a request-mapping annotation. */
    private static boolean isMappedMethod(Class<?> cls, String methodName) {
        for (Method candidate : cls.getDeclaredMethods()) {
            if (candidate.getName().equals(methodName)
                    && (candidate.isAnnotationPresent(RequestMapping.class)
                        || candidate.isAnnotationPresent(PostMapping.class)
                        || candidate.isAnnotationPresent(GetMapping.class))) {
                return true;
            }
        }
        return false;
    }

    /** All whitelist rows stored for the given controller class and method name. */
    private List<TsecurityParamcheckStrFilter> getParamFilterStrList(Class beanCls, String methodName) {
        TsecurityParamcheckStrFilterExample example = new TsecurityParamcheckStrFilterExample();
        example.createCriteria().andBeanEqualTo(beanCls.getName()).andMethodEqualTo(methodName);
        return paramcheckStrFilterMapper.selectByExample(example);
    }

    /**
     * Whitelisted characters from the rows matching parameter index and level
     * ("1" = method parameter, "2" = field of an object parameter) and, unless
     * {@code paramName} is empty, the parameter name.
     *
     * @return one entry per character (code point); empty if nothing is whitelisted
     */
    private List<String> getParamFilterStr(List<TsecurityParamcheckStrFilter> list, String paramIdx, String paramLevel,
                                           String paramName) {
        boolean anyName = ObjectUtil.isEmpty(paramName);
        String joined = list.stream()
                .filter(x -> paramIdx.equals(x.getParamIndex())
                        && paramLevel.equals(x.getParamLevel())
                        && (anyName || paramName.equals(x.getParamName())))
                .map(TsecurityParamcheckStrFilter::getFilterStr)
                // a null filterStr would otherwise be joined as the text "null"
                .filter(s -> s != null)
                .collect(Collectors.joining());
        // Split per code point so characters outside the BMP are not cut in half
        return joined.codePoints()
                .mapToObj(cp -> new String(Character.toChars(cp)))
                .collect(Collectors.toList());
    }

}

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值