Spring拦截器对客户端签名认证

本空间另有一篇文章,描述了如何通过Postman工具自动为请求报文添加签名。请求报文的格式及签名位置请参考《报文格式》一文。

下面给出服务器端如何签名认证的示例。 

 先定义拦截器:

package com.xxx.home.openapi.comm;

import java.io.IOException;
import java.io.PrintWriter;
import java.io.StringWriter;
import java.io.UnsupportedEncodingException;
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.text.SimpleDateFormat;
import java.time.Clock;
import java.time.LocalDateTime;
import java.time.format.DateTimeFormatter;
import java.util.*;
import java.util.Map.Entry;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

import com.alibaba.fastjson.JSONObject;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.autoconfigure.web.servlet.error.BasicErrorController;
import org.springframework.core.env.Environment;
import org.springframework.stereotype.Component;
import org.springframework.web.method.HandlerMethod;
import org.springframework.web.servlet.HandlerInterceptor;
import org.springframework.web.servlet.ModelAndView;

import com.xxx.home.config.ThrPlatInfoConfig;
import com.xxx.home.config.ThrPlatInfoConfig.ThrPlatConfigProp;
import com.xxx.home.openapi.constant.CommResultCode;
import com.xxx.home.openapi.exception.BussinessException;
import com.xxx.home.utils.MDUtils;

/**
 * Intercepts calls from third-party platforms to the open gateway and verifies the
 * signature of every request before it reaches a controller.
 *
 * <p>Signature rules:
 * <ul>
 * <li>The query string carries
 *     {@code ?signValidate={on|off}&custId={custId}&sign={sign}&other=xxx}.</li>
 * <li>{@code signValidate}: when the active profile is {@code beta}, verification is
 *     skipped if the parameter is empty or {@code off}; under any other profile the
 *     parameter is ignored and verification always runs.</li>
 * <li>{@code custId} identifies the third-party platform; its access secret is looked up
 *     from {@link ThrPlatInfoConfig}.</li>
 * <li>partA = all query parameters except {@code sign} whose value is not empty, sorted by
 *     parameter name in ascending order and joined as {@code key1=value1&key2=value2}.</li>
 * <li>partB = the access secret that belongs to {@code custId}.</li>
 * <li>partC = the request body, taken verbatim (line breaks included).</li>
 * <li>sign = hexadecimal MD5 of {@code partA + partB + partC}, encoded as UTF-8.</li>
 * <li>The request is rejected when the computed sign differs from the supplied one
 *     (compared case-insensitively); otherwise the interceptor chain continues.</li>
 * </ul>
 */
@Component
public class ApiSignInterceptor implements HandlerInterceptor {

    private static final String USERAGENT = "user-agent";

    /** Request attribute under which {@code preHandle} stores the request start time (epoch millis). */
    private static final String KEY_STARTTIME = "__startTime";

    /** Name of the query parameter that carries the client signature. */
    private static final String PARAM_SIGN = "sign";

    /** Headers consulted, in order, for the client address before falling back to the socket address. */
    private static final String[] CLIENT_IP_HEADERS = {"x-forwarded-for", "Proxy-Client-IP", "WL-Proxy-Client-IP"};

    /** Timestamp layout of the access log; {@code DateTimeFormatter} is thread-safe and can be shared. */
    private static final DateTimeFormatter ACCESS_TIME_FORMAT = DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss");

    /** Dotted-quad IPv4 address whose first octet is 1-255; compiled once instead of on every call. */
    private static final Pattern IPV4_PATTERN = Pattern.compile(
            "([1-9]|[1-9]\\d|1\\d{2}|2[0-4]\\d|25[0-5])(\\.(\\d|[1-9]\\d|1\\d{2}|2[0-4]\\d|25[0-5])){3}");

    private final Logger logger = LoggerFactory.getLogger(getClass());

    @Autowired
    private ThrPlatInfoConfig thrPlatInfoConfig;

    @Autowired
    private Environment env;

    /**
     * Records the start time, writes the access log and verifies the request signature.
     *
     * @return always {@code true}; a failed verification is reported by throwing
     * @throws Exception the exception carried by the failed {@link ValidateResponse}
     */
    @Override
    public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler)
            throws Exception {
        request.setAttribute(KEY_STARTTIME, Clock.systemDefaultZone().millis());
        logAccess(request);

        if (!(handler instanceof HandlerMethod)) {
            return true;
        }
        // Requests routed to the framework's BasicErrorController are not signed.
        if (((HandlerMethod) handler).getBean() instanceof BasicErrorController) {
            return true;
        }

        // NOTE(review): the wrapper only lives inside this interceptor; the controller still
        // receives the original request. If MyRequestWrapper reads the body from the request
        // input stream, confirm that a filter wraps the request upstream so the body remains
        // readable afterwards.
        MyRequestWrapper wrappedRequest = new MyRequestWrapper(request);
        // Log the request content before verifying the signature.
        printSysLog(wrappedRequest);

        ValidateResponse result = paramSignValidate(wrappedRequest);
        if (!result.isValidate()) {
            throw result.getException();
        }
        return true;
    }

    @Override
    public void postHandle(HttpServletRequest request, HttpServletResponse response, Object handler,
            ModelAndView modelAndView) {
        // Runs after the controller method but before view rendering; nothing to do here.
    }

    /**
     * Logs a controller exception (if any) and the elapsed time since {@code preHandle},
     * then removes the start-time attribute from the request.
     */
    @Override
    public void afterCompletion(HttpServletRequest request, HttpServletResponse response, Object handler,
            Exception ex) {
        if (ex != null) {
            // Passing the throwable lets the logging backend render the stack trace.
            logger.error("Controller exception", ex);
        }

        Object startTime = request.getAttribute(KEY_STARTTIME);
        if (startTime instanceof Long) {
            logger.info("Controller cost(ms):{}", Clock.systemDefaultZone().millis() - (Long) startTime);
        }
        request.removeAttribute(KEY_STARTTIME);
    }

    /** Logs the query parameters, body and URI of the request. */
    private void printSysLog(MyRequestWrapper request) {
        StringBuilder sb = new StringBuilder();
        sb.append("URL Params: ").append(getParamString(request.getParameterMap())).append("\n");
        sb.append("Body      : ").append(request.getBody()).append("\n");
        sb.append("URI       : ").append(request.getRequestURI());
        logger.info("HttpRequest:{}", sb);
    }

    /**
     * Renders a parameter map as {@code key=value&} pairs for logging. A single-valued
     * parameter prints its value; anything else prints the whole array.
     */
    private String getParamString(Map<String, String[]> map) {
        StringBuilder sb = new StringBuilder();
        for (Entry<String, String[]> e : map.entrySet()) {
            String[] value = e.getValue();
            String rendered = (value != null && value.length == 1) ? value[0] : Arrays.toString(value);
            sb.append(e.getKey()).append("=").append(rendered).append("&");
        }
        return sb.toString();
    }

    /**
     * Verifies the signature of the request.
     *
     * @param request wrapped request giving access to the query parameters and the body
     * @return a successful response, or a failed one carrying the exception to throw
     */
    private ValidateResponse paramSignValidate(MyRequestWrapper request) {
        String activeEnv = env.getProperty("spring.profiles.active");
        String signValidate = request.getParameter("signValidate");
        // On beta the caller decides whether the signature is verified; the default is off.
        if ("beta".equalsIgnoreCase(activeEnv)
                && (StringUtils.isEmpty(signValidate) || "off".equalsIgnoreCase(signValidate))) {
            return new ValidateResponse(true, null);
        }

        String custId = request.getParameter("custId");
        String signThrplat = request.getParameter(PARAM_SIGN);
        if (StringUtils.isEmpty(signThrplat) || StringUtils.isEmpty(custId)) {
            return authFailure();
        }

        // Look up the access secret registered for this customer id.
        ThrPlatConfigProp thrPlatConfigInfo = thrPlatInfoConfig.getConfigByCustId(custId);
        if (thrPlatConfigInfo == null) {
            return authFailure();
        }
        String appSecret = thrPlatConfigInfo.getAppSecret();
        if (StringUtils.isEmpty(appSecret)) {
            return authFailure();
        }

        // Recompute the signature on our side and compare it with the one the client sent.
        String signMe;
        try {
            signMe = getSign(request, appSecret);
        }
        catch (IOException e) {
            logger.error("Failed to compute the request signature", e);
            return new ValidateResponse(false, new BussinessException(CommResultCode.SYSTEM_ERROR));
        }
        if (!signaturesMatch(signThrplat, signMe)) {
            return authFailure();
        }
        return new ValidateResponse(true, null);
    }

    /** Builds the response used for every authentication failure. */
    private static ValidateResponse authFailure() {
        return new ValidateResponse(false, new BussinessException(CommResultCode.AK_FAILURE));
    }

    /**
     * Compares two signatures case-insensitively. {@link MessageDigest#isEqual} is used
     * instead of {@code String} equality so the comparison time does not depend on how
     * many leading characters match.
     */
    private static boolean signaturesMatch(String provided, String computed) {
        if (provided == null || computed == null) {
            return false;
        }
        byte[] providedBytes = provided.toLowerCase(Locale.ROOT).getBytes(StandardCharsets.UTF_8);
        byte[] computedBytes = computed.toLowerCase(Locale.ROOT).getBytes(StandardCharsets.UTF_8);
        return MessageDigest.isEqual(providedBytes, computedBytes);
    }

    /**
     * Computes the expected signature: the sorted, non-empty query parameters (without
     * {@code sign}) joined as {@code k=v&k=v}, followed by the access secret and the body,
     * hashed through {@code MDUtils.md5EncodeForHex} with UTF-8.
     *
     * @param request   wrapped request supplying the parameters and the body
     * @param appSecret access secret of the calling platform
     * @return the value returned by {@code MDUtils.md5EncodeForHex}
     * @throws BussinessException if UTF-8 or MD5 is not available
     */
    private String getSign(MyRequestWrapper request, String appSecret) throws IOException {
        String bodyStr = request.getBody();

        // TreeMap keeps the parameter names in ascending order.
        TreeMap<String, String> params = new TreeMap<>();
        Enumeration<String> names = request.getParameterNames();
        while (names.hasMoreElements()) {
            String paramName = names.nextElement().trim();
            String paramValue = request.getParameter(paramName);
            if (!PARAM_SIGN.equals(paramName) && StringUtils.isNotEmpty(paramValue)) {
                params.put(paramName, paramValue);
            }
        }

        StringJoiner query = new StringJoiner("&");
        for (Map.Entry<String, String> entry : params.entrySet()) {
            query.add(entry.getKey() + "=" + entry.getValue());
        }
        String payload = new StringBuilder(query.toString()).append(appSecret).append(bodyStr).toString();

        try {
            return MDUtils.md5EncodeForHex(payload, "utf-8");
        }
        catch (UnsupportedEncodingException e) {
            logger.error("UTF-8 is unsupported", e);
            throw new BussinessException(CommResultCode.SYSTEM_ERROR);
        }
        catch (NoSuchAlgorithmException e) {
            logger.error("MessageDigest不支持MD5", e);
            throw new BussinessException(CommResultCode.SYSTEM_ERROR);
        }
    }

    /**
     * Outcome of a signature check: a flag plus the exception to throw on failure.
     */
    private static class ValidateResponse {
        private final boolean validate;
        private final BussinessException exception;

        public ValidateResponse(boolean validate, BussinessException exception) {
            this.validate = validate;
            this.exception = exception;
        }

        public boolean isValidate() {
            return validate;
        }

        public Exception getException() {
            return exception;
        }
    }

    /**
     * Writes one access-log line (URL, client IP, time, first value of every parameter,
     * user agent, referer, forwarded protocol) as JSON. Logging is best effort: any
     * failure is logged and never propagates to the request.
     */
    private void logAccess(HttpServletRequest request) {
        try {
            String referer = request.getHeader("referer");
            if (null != referer && referer.length() > 300) {
                referer = referer.substring(0, 297) + "...";
            }

            JSONObject accessLog = new JSONObject();
            accessLog.put("accessURL", request.getRequestURL().toString());
            accessLog.put("accessIP", getReqIp(request));
            accessLog.put("accessTimeStr", LocalDateTime.now().format(ACCESS_TIME_FORMAT));

            Map<String, String[]> parameterMap = request.getParameterMap();
            for (Map.Entry<String, String[]> entry : parameterMap.entrySet()) {
                String[] values = entry.getValue();
                // Guard against an empty value array so one odd parameter cannot abort the log line.
                accessLog.put(entry.getKey(), (values != null && values.length > 0) ? values[0] : null);
            }

            String userAgent = request.getHeader(USERAGENT);
            if (null != userAgent) {
                accessLog.put("user-agent", userAgent);
            }
            if (null != referer) {
                accessLog.put("referer", referer);
            }
            String pro = request.getHeader("X-Forwarded-Proto");
            if (com.xxx.home.utils.StringUtils.isNotEmpty(pro)) {
                accessLog.put("Proto", pro);
            }
            logger.info("accessLog-->{}", accessLog);
        } catch (Exception e) {
            // Deliberately broad: access logging must never break request handling.
            logger.error("打印log出错", e);
        }
    }

    /**
     * Resolves the client address from the proxy headers, falling back to the socket
     * address. When the value is a comma-separated list, the first entry that is a valid
     * IPv4 address is returned; if none is valid the whole value is returned unchanged.
     *
     * @param request the current request
     * @return the resolved address, possibly {@code null}
     */
    private String getReqIp(HttpServletRequest request) {
        String ip = null;
        for (String header : CLIENT_IP_HEADERS) {
            ip = request.getHeader(header);
            if (!isUnknownIp(ip)) {
                break;
            }
        }
        if (isUnknownIp(ip)) {
            ip = request.getRemoteAddr();
        }

        if (ip != null && ip.contains(",")) {
            for (String candidate : ip.split(",")) {
                // Proxies usually separate entries with ", ", so strip the surrounding blanks.
                String trimmed = candidate.trim();
                if (isIP(trimmed)) {
                    return trimmed;
                }
            }
        }
        return ip;
    }

    /** Returns whether a header value is absent, empty or the literal {@code unknown}. */
    private static boolean isUnknownIp(String ip) {
        return ip == null || ip.length() == 0 || ip.equalsIgnoreCase("unknown");
    }

    /**
     * Returns whether the whole string is a dotted-quad IPv4 address. The full string must
     * match; a string that merely contains an address is rejected.
     */
    private boolean isIP(String addr) {
        return addr != null
                && addr.length() >= 7
                && addr.length() <= 15
                && IPV4_PATTERN.matcher(addr).matches();
    }
}

然后将其注册到拦截器链中:

/**
 * Web MVC configuration of the open gateway: registers the signature interceptor,
 * the CORS mapping and the fastjson message converter.
 */
@EnableWebMvc
@Configuration
public class OpenAPIWebConfig implements WebMvcConfigurer {

    /** Message converter injected from the application context. */
    @Autowired
    private HttpMessageConverter<?> fastJsonHttpMessageConverter;

    /**
     * Creates the signature interceptor bean.
     *
     * <p>NOTE(review): {@code ApiSignInterceptor} is also annotated with {@code @Component},
     * so component scanning may register a second instance — confirm which one is intended.
     */
    @Bean
    public HandlerInterceptor getHandlerInterceptor() {
        return new ApiSignInterceptor();
    }

    /** Applies the signature interceptor to every path under {@code /ulehomeapi/}. */
    @Override
    public void addInterceptors(InterceptorRegistry registry) {
        HandlerInterceptor signInterceptor = getHandlerInterceptor();
        registry.addInterceptor(signInterceptor).addPathPatterns("/ulehomeapi/**");
    }

    /**
     * Opens every path to cross-origin requests.
     *
     * <p>NOTE(review): a wildcard origin combined with {@code allowCredentials(true)} is
     * rejected by newer Spring versions and is very permissive on older ones — confirm the
     * Spring version in use and the origins that should really be allowed.
     */
    @Override
    public void addCorsMappings(CorsRegistry registry) {
        registry.addMapping("/**")
                .allowedOrigins("*")
                .allowedMethods("PUT", "POST", "GET", "DELETE", "OPTIONS")
                .allowedHeaders("Content-Type", "x-requested-with", "X-Custom-Header")
                .allowCredentials(true);
    }

    /** Adds the injected fastjson converter to the list of message converters. */
    @Override
    public void configureMessageConverters(List<HttpMessageConverter<?>> converters) {
        converters.add(fastJsonHttpMessageConverter);
    }
}

  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值