执行流程说明请看上一篇。本篇只涉及代码编写
1、新建 util 包,放入以下类:
BaseRedisService
import java.util.concurrent.TimeUnit;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.data.redis.core.StringRedisTemplate;
import org.springframework.stereotype.Component;
/**
 * BaseRedisService
 * <p>
 * Thin wrapper around {@link StringRedisTemplate} for storing, reading, deleting and inspecting
 * the TTL of string values.
 */
@Component
public class BaseRedisService {
    @Autowired
    private StringRedisTemplate stringRedisTemplate;

    /**
     * Stores {@code data} (converted with {@link String#valueOf(Object)}) under {@code key}.
     * <p>
     * When {@code timeout} is positive, the value and its TTL are written with a single
     * SET-with-expiry command. The key therefore can never exist without an expiry, which a
     * separate SET followed by EXPIRE cannot guarantee.
     *
     * @param key     Redis key
     * @param data    value to store; if {@code null}, nothing is written and only the TTL of an
     *                existing key is (re)applied
     * @param timeout time-to-live in seconds, or {@code null} for no expiry
     */
    public void setString(String key, Object data, Long timeout) {
        if (data != null) {
            String value = String.valueOf(data);
            if (timeout != null && timeout > 0) {
                // Atomic SET key value EX timeout.
                stringRedisTemplate.opsForValue().set(key, value, timeout, TimeUnit.SECONDS);
                return;
            }
            stringRedisTemplate.opsForValue().set(key, value);
        }
        if (timeout != null) {
            // Non-positive timeouts (and null data) keep the original SET + EXPIRE semantics.
            stringRedisTemplate.expire(key, timeout, TimeUnit.SECONDS);
        }
    }

    /**
     * Returns the string stored under {@code key}, or {@code null} if there is none.
     *
     * @param key Redis key
     * @return stored value or {@code null}
     */
    public Object getString(String key) {
        return stringRedisTemplate.opsForValue().get(key);
    }

    /**
     * Deletes {@code key}.
     *
     * @param key Redis key
     */
    public void delKey(String key) {
        stringRedisTemplate.delete(key);
    }

    /**
     * Returns the remaining time-to-live of {@code key} in seconds, as a string.
     * <ul>
     * <li>If the key has an expiry, its remaining TTL is returned.</li>
     * <li>If the key has no expiry, "-1" is returned.</li>
     * <li>If the key does not exist, "-2" is returned.</li>
     * </ul>
     *
     * @param key Redis key
     * @return remaining TTL as text ("null" if the template returned no value)
     */
    public String getExpire(String key) {
        return String.valueOf(stringRedisTemplate.getExpire(key, TimeUnit.SECONDS));
    }
}
Md5Util
import java.security.MessageDigest;
/**
 * MD5Util
 * <p>
 * Computes MD5 digests rendered as upper-case hexadecimal strings.
 */
public class Md5Util {
    /** Upper-case hex alphabet used to render digest bytes. */
    private static final char[] HEX_DIGITS = "0123456789ABCDEF".toCharArray();

    /**
     * Returns the MD5 digest of the UTF-8 bytes of {@code s} as a 32-character upper-case hex string.
     * <p>
     * UTF-8 is used explicitly instead of the platform default charset. This makes non-ASCII input
     * hash identically on every server, and matches request bodies decoded as UTF-8.
     *
     * @param s text to hash; may be {@code null}
     * @return upper-case hex digest, or {@code null} if {@code s} is {@code null}
     */
    public final static String getMD5(String s) {
        if (s == null) {
            return null;
        }
        try {
            MessageDigest mdInst = MessageDigest.getInstance("MD5");
            byte[] digest = mdInst.digest(s.getBytes(java.nio.charset.StandardCharsets.UTF_8));
            char[] hex = new char[digest.length * 2];
            for (int i = 0; i < digest.length; i++) {
                // Mask to 0..255 so the high nibble is not polluted by sign extension.
                int v = digest[i] & 0xff;
                hex[2 * i] = HEX_DIGITS[v >>> 4];
                hex[2 * i + 1] = HEX_DIGITS[v & 0x0f];
            }
            return new String(hex);
        } catch (java.security.NoSuchAlgorithmException e) {
            // Every Java platform is required to provide MD5; reaching this means a broken JRE.
            throw new IllegalStateException("MD5 algorithm not available", e);
        }
    }
}
SortUtils
import com.alibaba.fastjson.JSONObject;
import com.app.config.ResourceConfig;
import org.junit.platform.commons.util.StringUtils;
import java.util.*;
import java.util.Map.Entry;
/**
 * SortUtils
 * <p>
 * Builds the canonical parameter strings used for request signing. Entries are sorted by key
 * (natural, case-sensitive {@link String} order), entries with blank keys are skipped, and the
 * rest are rendered as {@code key=value} and joined with {@code &}. Values are rendered with
 * {@link String#valueOf(Object)}, so {@code null} becomes {@code "null"}.
 *
 * @Author zhanglijun
 */
public class SortUtils {
    /**
     * Configuration bean holding the TongCheng MD5 key.
     * <p>
     * It is resolved lazily from the Spring context on first use. Loading this class therefore
     * does not require the context (or the "resourceConfig" bean) to exist yet, and a missing bean
     * no longer breaks {@link #hxFormatUrlParam}, which does not need it.
     */
    public static ResourceConfig resourceConfig;

    /**
     * Formats parameters for the TongCheng API: sorted {@code key=value} pairs followed by
     * {@code md5key=<configured key>}.
     *
     * @param param   parameters in Map form
     * @param isLower whether to lower-case the keys
     * @return the formatted string, or "" if {@code param} is null or formatting fails
     */
    public static String formatUrlParam(Map<String, Object> param, boolean isLower) {
        return formatWithTcKey(param, isLower);
    }

    /**
     * Formats parameters for the TongCheng API: sorted {@code key=value} pairs followed by
     * {@code md5key=<configured key>}.
     * <p>
     * Non-string JSON values (numbers, booleans, nested objects) are rendered with
     * {@link String#valueOf(Object)}. Previously they caused a ClassCastException and an empty result.
     *
     * @param jsonObject parameters in JSONObject form
     * @param isLower    whether to lower-case the keys
     * @return the formatted string, or "" if {@code jsonObject} is null or formatting fails
     */
    public static String formatUrlParam(JSONObject jsonObject, boolean isLower) {
        return formatWithTcKey(jsonObject, isLower);
    }

    /**
     * Formats parameters for external callers: sorted {@code key=value} pairs with
     * {@code appSecret} appended directly, with no separator.
     *
     * @param param     parameters in Map form
     * @param isLower   whether to lower-case the keys
     * @param appSecret the merchant's appSecret
     * @return the formatted string, or "" if {@code param} is null or formatting fails
     */
    public static String hxFormatUrlParam(Map<String, Object> param, boolean isLower, String appSecret) {
        if (param == null) {
            return "";
        }
        try {
            return joinSorted(param, isLower).toString() + appSecret;
        } catch (Exception e) {
            // Best-effort contract kept from the original implementation: any failure yields "".
            return "";
        }
    }

    /** Shared implementation of both {@code formatUrlParam} overloads. */
    private static String formatWithTcKey(Map<String, Object> param, boolean isLower) {
        if (param == null) {
            return "";
        }
        try {
            StringJoiner joiner = joinSorted(param, isLower);
            joiner.add("md5key" + "=" + tcMd5Key());
            return joiner.toString();
        } catch (Exception e) {
            // Best-effort contract kept from the original implementation: any failure yields "".
            return "";
        }
    }

    /** Sorts entries by key and joins the non-blank-key ones as {@code key=value} with {@code &}. */
    private static StringJoiner joinSorted(Map<String, Object> param, boolean isLower) {
        List<Map.Entry<String, Object>> items = new ArrayList<>(param.entrySet());
        items.sort(Map.Entry.comparingByKey());
        StringJoiner joiner = new StringJoiner("&");
        for (Map.Entry<String, Object> item : items) {
            String key = item.getKey();
            if (key == null || key.trim().isEmpty()) {
                continue;
            }
            // Locale.ROOT: lower-casing must not depend on the server locale (e.g. Turkish 'I').
            String name = isLower ? key.toLowerCase(Locale.ROOT) : key;
            joiner.add(name + "=" + String.valueOf(item.getValue()));
        }
        return joiner;
    }

    /** Returns the TongCheng MD5 key, resolving {@link #resourceConfig} on first use. */
    private static String tcMd5Key() {
        if (resourceConfig == null) {
            resourceConfig = (ResourceConfig) SpringUtil.getBean("resourceConfig");
        }
        return resourceConfig.getTcMd5Key();
    }
}
SpringUtil
import org.springframework.beans.BeansException;
import org.springframework.beans.factory.NoSuchBeanDefinitionException;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextAware;
import org.springframework.stereotype.Component;
/**
 * Static access to the Spring {@link ApplicationContext}.
 * <p>
 * In Spring Boot, plain (non-managed) classes cannot use {@code @Autowired}. They can fetch beans
 * through this utility instead, e.g. {@code SpringUtil.getBean(XXX.class)}. The context is
 * captured when Spring creates this {@code @Component}; calls made before that point fail with
 * {@link IllegalStateException}.
 */
@Component
public class SpringUtil implements ApplicationContextAware {
    private static ApplicationContext applicationContext = null;

    public SpringUtil() {
    }

    /** Captures the first context Spring hands us; later calls are ignored. */
    @Override
    public void setApplicationContext(ApplicationContext arg0) throws BeansException {
        if (applicationContext == null) {
            applicationContext = arg0;
        }
    }

    /** @return the captured context, or {@code null} if it has not been set yet */
    public static ApplicationContext getApplicationContext() {
        return applicationContext;
    }

    /** Replaces the captured context with {@code webAppCtx} when it is non-null. */
    public static void setAppCtx(ApplicationContext webAppCtx) {
        if (webAppCtx != null) {
            applicationContext = webAppCtx;
        }
    }

    /** Returns the context or fails with a descriptive error instead of a bare NPE. */
    private static ApplicationContext requireContext() {
        ApplicationContext ctx = applicationContext;
        if (ctx == null) {
            throw new IllegalStateException("ApplicationContext is not initialized yet: the SpringUtil bean has not been created");
        }
        return ctx;
    }

    /**
     * Once the ApplicationContext is available, beans can be fetched manually.
     */
    public static <T> T getBean(Class<T> clazz) {
        return requireContext().getBean(clazz);
    }

    public static <T> T getBean(String name, Class<T> clazz) throws ClassNotFoundException {
        return requireContext().getBean(name, clazz);
    }

    public static final Object getBean(String beanName) {
        return requireContext().getBean(beanName);
    }

    /**
     * Looks up bean {@code beanName} and requires it to be an instance of {@code className}.
     * The original passed {@code clz.getClass()} (i.e. {@code java.lang.Class}), so the lookup
     * always failed; the resolved class itself is now used.
     */
    public static final Object getBean(String beanName, String className) throws ClassNotFoundException {
        Class<?> clz = Class.forName(className);
        return requireContext().getBean(beanName, clz);
    }

    public static boolean containsBean(String name) {
        return requireContext().containsBean(name);
    }

    public static boolean isSingleton(String name) throws NoSuchBeanDefinitionException {
        return requireContext().isSingleton(name);
    }

    public static Class<?> getType(String name) throws NoSuchBeanDefinitionException {
        return requireContext().getType(name);
    }

    public static String[] getAliases(String name) throws NoSuchBeanDefinitionException {
        return requireContext().getAliases(name);
    }
}
BeanCopyUtil
import org.springframework.cglib.beans.BeanCopier;
/**
* 复制对象
*
*/
public class BeanCopyUtil {
/**
* 无转换方法只能复制相同名称相同类型的bean
*
* @param srcObj 源对象
* @param targetObj 目的对象
*/
public static void copy(Object srcObj, Object targetObj) {
BeanCopier beanCopier = BeanCopier.create(srcObj.getClass(), targetObj.getClass(), false);
beanCopier.copy(srcObj, targetObj, null);
}
}
2、新建handler,放入以下类(核心拦截)
CheckInterceptor
package com.app.handler;
import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONObject;
import com.alibaba.fastjson.parser.Feature;
import com.app.constant.ErrorEnum;
import com.app.entity.App;
import com.app.service.AppService;
import com.wii.spring.web.payload.ResponseContent;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import org.springframework.web.servlet.HandlerInterceptor;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.PrintWriter;
import java.util.*;
/**
* checkInterceptor
*
*/
@Slf4j
@Component
public class CheckInterceptor implements HandlerInterceptor {
@Autowired
private BaseRedisService baseRedisService;
@Autowired
private AppService appService;
/**
* 进入controller层之前拦截请求,进行参数校验。
* 1、获取sign和accesstoken进行非空校验
* 2、遍历所有的key赋值到map中,移除map中的sign
* 3、根据accessToken查看redis中是否存在,存在则取accessToken中对应的id,根据id去数据库查此商户数据。能查到则取appSecret
* 4、对所有的参数进行字典升序小写,在升序的最后面加上商户的appSecret
* 5、对排序后的内容进行MD5加密生成sign
* 6、判断传入的sign和自己排序加密生成的sign是否一致,一致则通过,不一致则提示
*
* @param httpServletRequest
* @param httpServletResponse
* @param o
* @return
* @throws Exception
*/
@Override
public boolean preHandle(HttpServletRequest httpServletRequest, HttpServletResponse httpServletResponse, Object o) throws IOException {
BufferedReader streamReader = null;
try {
streamReader = new BufferedReader(new InputStreamReader(httpServletRequest.getInputStream(), "UTF-8"));
StringBuilder responseStrBuilder = new StringBuilder();
String inputStr;
while ((inputStr = streamReader.readLine()) != null) {
responseStrBuilder.append(inputStr);
}
//json获取,value不要排序。
JSONObject jsonObject = JSONObject.parseObject(responseStrBuilder.toString(), Feature.OrderedField);
//获取sign
String sign = jsonObject.getString("sign");
//获取access_token
String accessToken = jsonObject.getString("access_token");
if (StringUtils.isEmpty(accessToken)) {
//校验失败,缺少accessToken
resultError("access_token验证失败", httpServletResponse);
return false;
}
if (StringUtils.isEmpty(sign)) {
//校验失败,缺少sign
resultError("验证失败", httpServletResponse);
return false;
}
LinkedHashMap map = new LinkedHashMap();
Set<Map.Entry<String, Object>> entries = jsonObject.entrySet();
//遍历所有的key
Iterator iter = entries.iterator();
while (iter.hasNext()) {
Map.Entry entry = (Map.Entry) iter.next();
map.put(entry.getKey(), entry.getValue());
}
map.remove("sign");
//根据accessToken获取id
Object obj = baseRedisService.getString(accessToken);
String appSecret = "";
//获取appSecret
if (obj != null) {
String id = obj.toString();
//根据id查询数据
App app = appService.findById(id).get();
if (app == null) {
resultError("access_token验证失败", httpServletResponse);
return false;
}
appSecret = app.getAppSecret();
}
//对所有的参数进行字典升序,小写
String paramStr = SortUtils.hxFormatUrlParam(map, true, appSecret);
//MD5加密生成sign,转大写
String md5sign = Md5Util.getMD5(paramStr).toUpperCase();
//判断sign是否和生成的md5sign一致,一致则通过
if (!sign.equals(md5sign)) {
resultError("验证失败", httpServletResponse);
return false;
}
} catch (Exception e) {
log.error("校验报错:{}", e.getMessage());
resultError("格式不正确", httpServletResponse);
return false;
}
return true;
}
/**
* 返回错误提示
*
* @param errorMsg
* @param httpServletResponse
* @throws IOException
*/
public void resultError(String errorMsg, HttpServletResponse httpServletResponse) {
httpServletResponse.setCharacterEncoding("UTF-8");
httpServletResponse.setContentType("text/json; charset=UTF-8");
PrintWriter printWriter = null;
try {
printWriter = httpServletResponse.getWriter();
ResponseContent responseContent = new ResponseContent();
responseContent.setErrorCode(ErrorEnum.ERROR_10001.getValue());
responseContent.setCode(ErrorEnum.ERROR_10001.getValue());
responseContent.setRet(1);
responseContent.setMsg(errorMsg);
String resStr = JSON.toJSON(responseContent).toString();
printWriter.write(resStr);
printWriter.flush();
} catch (Exception e) {
log.error("CheckInterceptor.resultError异常:{}", e.getMessage());
e.printStackTrace();
} finally {
if (printWriter != null) {
printWriter.close();
}
}
}
}
WebAppConfig
package com.app.handler;
import com.beust.jcommander.internal.Lists;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.servlet.config.annotation.InterceptorRegistry;
import org.springframework.web.servlet.config.annotation.ResourceHandlerRegistry;
import org.springframework.web.servlet.config.annotation.WebMvcConfigurer;
import java.util.List;
/**
 * Registers {@link CheckInterceptor} for request paths matching {@code /*}{@code /**}, except
 * static resources under {@code /static/**}.
 */
@Configuration
public class WebAppConfig {
    @Autowired
    private CheckInterceptor checkInterceptor;

    @Bean
    public WebMvcConfigurer webMvcConfigurer() {
        return new WebMvcConfigurer() {
            /**
             * addPathPatterns: requests matching these patterns go through the interceptor.
             * excludePathPatterns: requests matching these patterns skip it.
             *
             * @param registry interceptor registry
             */
            @Override
            public void addInterceptors(InterceptorRegistry registry) {
                // The varargs overload avoids depending on com.beust.jcommander.internal.Lists
                // (an internal JCommander class that is normally only on the test classpath).
                // NOTE(review): "/*/**" appears to also match the token endpoint
                // (/get_access_token), which cannot carry an access_token yet. Confirm how that
                // endpoint is reached and exclude it here if needed.
                registry.addInterceptor(checkInterceptor)
                        .addPathPatterns("/*/**")
                        // Static resources are not verified.
                        .excludePathPatterns("/static/**");
            }
        };
    }
}
3、获取accessToken接口
/**
 * Issues an accessToken for an app.
 * <p>
 * The caller proves knowledge of the appSecret by sending
 * {@code secret = UPPER(MD5(app_id + app_secret + timestamp))}. If the app's current token is
 * still present in Redis, it is returned with its remaining TTL. Otherwise a new token is
 * generated, stored in Redis and persisted on the app record.
 *
 * @param appParam request carrying app_id, timestamp and secret
 * @return {@code access_token} and {@code expire} (remaining TTL in seconds) on success, or an error
 */
@ApiOperation("获取accessToken接口")
@PostMapping("/get_access_token")
public ResponseContent getAccessToken(@RequestBody @Validated @ApiParam(name = "授权码查询对象", value = "传入json") AppTokenParam appParam) {
    String timestamp = appParam.getTimestamp();
    String secret = appParam.getSecret();
    App app = appService.findByAppId(appParam.getAppId());
    if (app == null) {
        return ResponseContent.error(ErrorConstants.ERROR_11005.getCode(), ErrorConstants.ERROR_11005.getMsg());
    }
    if (app.getDisabled() == 1) {
        return ResponseContent.error(ErrorConstants.ERROR_11001.getCode(), ErrorConstants.ERROR_11001.getMsg());
    }
    // Expected secret: MD5(app_id + app_secret + timestamp), upper-cased.
    String md5String = MD5Utils.md5(app.getAppId() + app.getAppSecret() + timestamp).toUpperCase();
    // Constant-time comparison so response timing does not reveal how much of the secret matched.
    if (secret == null || !java.security.MessageDigest.isEqual(
            md5String.getBytes(java.nio.charset.StandardCharsets.UTF_8),
            secret.getBytes(java.nio.charset.StandardCharsets.UTF_8))) {
        return ResponseContent.error(EnumError.ERROR_50002.getCode(), EnumError.ERROR_50002.getMsg());
    }
    String accessToken = app.getAccessToken();
    Map<String, Object> resMap = Maps.newHashMap();
    // If the current token still exists in Redis, return it together with its remaining TTL.
    if (StringUtils.isNotBlank(accessToken)) {
        if (baseRedisService.getString(accessToken) != null) {
            resMap.put("access_token", accessToken);
            resMap.put("expire", baseRedisService.getExpire(accessToken));
            return ResponseContent.OK(resMap);
        }
    }
    // Generate a new accessToken (also stored in Redis).
    String newAccessToken = newAccessToken(app.getId(), app.getAppId(), app.getAppSecret());
    // Persist the new token on a copy of the app record.
    App updateApp = new App();
    BeanCopyUtil.copy(app, updateApp);
    updateApp.setAccessToken(newAccessToken);
    appService.save(updateApp);
    resMap.put("access_token", newAccessToken);
    resMap.put("expire", baseRedisService.getExpire(newAccessToken));
    return ResponseContent.OK(resMap);
}
/**
 * Generates an access token from appId and appSecret and stores it in Redis.
 * The Redis key is the token and the value is the app's id. The TTL in seconds is read from
 * {@code resourceConfig.getHxTokenTime()}.
 *
 * @return the newly generated token
 */
private String newAccessToken(String id, String appId, String appSecret) {
    String token = TokenUtils.getAccessToken(appId, appSecret);
    long ttlSeconds = Long.parseLong(resourceConfig.getHxTokenTime());
    baseRedisService.setString(token, id, ttlSeconds);
    return token;
}
注意:拦截器中通过 `new BufferedReader(new InputStreamReader(httpServletRequest.getInputStream(), "UTF-8"))` 读取了请求体。请求体的输入流只能读取一次,读取之后 Controller 中的 @RequestBody 就无法再获取参数。解决方案:新增一个 Filter,在请求进入拦截器之前先把请求体缓存下来,之后每次读取都从缓存中返回即可。
4、新建filter
import com.app.util.HttpContextUtils;
import org.springframework.core.annotation.Order;
import org.springframework.stereotype.Component;
import javax.servlet.*;
import javax.servlet.annotation.WebFilter;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import java.io.BufferedReader;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
/**
 * HttpServletRequestFilter
 * <p>
 * Wraps HTTP requests in a {@link RequestWrapper} that buffers the raw request body. This lets
 * the body be read more than once, e.g. once by an interceptor and again for {@code @RequestBody}.
 * Multipart and form-urlencoded requests are passed through unwrapped: buffering would load whole
 * uploads into memory, and consuming the stream would stop the container from parsing form parameters.
 */
@Component
@WebFilter(filterName = "HttpServletRequestFilter", urlPatterns = "/")
@Order(10000)
public class HttpServletRequestFilter implements Filter {
    @Override
    public void init(FilterConfig filterConfig) throws ServletException {
    }

    /**
     * Passes the (possibly wrapped) request down the chain.
     */
    @Override
    public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain) throws IOException, ServletException {
        ServletRequest request = servletRequest;
        if (servletRequest instanceof HttpServletRequest && shouldBuffer((HttpServletRequest) servletRequest)) {
            request = new RequestWrapper((HttpServletRequest) servletRequest);
        }
        filterChain.doFilter(request, servletResponse);
    }

    /** Returns false for multipart and form-urlencoded bodies, which must not be pre-consumed. */
    private static boolean shouldBuffer(HttpServletRequest request) {
        String contentType = request.getContentType();
        if (contentType == null) {
            return true;
        }
        String type = contentType.toLowerCase(java.util.Locale.ROOT);
        return !type.startsWith("multipart/") && !type.startsWith("application/x-www-form-urlencoded");
    }

    @Override
    public void destroy() {
    }

    /**
     * HttpServletRequest wrapper that makes the body re-readable.
     * <p>
     * The raw body bytes are captured once in the constructor. Every call to
     * {@link #getInputStream()} / {@link #getReader()} then returns a fresh stream over those same
     * bytes (no line-ending or charset round-trip, so binary-safe).
     */
    public static class RequestWrapper extends HttpServletRequestWrapper {
        /** Raw request body. */
        private final byte[] body;

        public RequestWrapper(HttpServletRequest request) {
            super(request);
            try {
                this.body = readFully(request.getInputStream());
            } catch (IOException e) {
                // Fail the request rather than continue with a silently truncated body.
                throw new java.io.UncheckedIOException("Failed to read request body", e);
            }
        }

        /** Reads {@code in} to the end and returns its bytes. */
        private static byte[] readFully(java.io.InputStream in) throws IOException {
            java.io.ByteArrayOutputStream out = new java.io.ByteArrayOutputStream();
            byte[] buffer = new byte[4096];
            int n;
            while ((n = in.read(buffer)) != -1) {
                out.write(buffer, 0, n);
            }
            return out.toByteArray();
        }

        /**
         * Returns the request body decoded with the request's character encoding (UTF-8 if none or unknown).
         *
         * @return request body text
         */
        public String getBody() {
            return new String(body, bodyCharset());
        }

        /** Request character encoding, falling back to UTF-8 when absent or unsupported. */
        private java.nio.charset.Charset bodyCharset() {
            String encoding = getCharacterEncoding();
            if (encoding != null) {
                try {
                    return java.nio.charset.Charset.forName(encoding);
                } catch (IllegalArgumentException ignored) {
                    // Unknown or malformed charset name: fall back to UTF-8.
                }
            }
            return StandardCharsets.UTF_8;
        }

        @Override
        public BufferedReader getReader() throws IOException {
            return new BufferedReader(new InputStreamReader(getInputStream(), bodyCharset()));
        }

        @Override
        public ServletInputStream getInputStream() throws IOException {
            final ByteArrayInputStream in = new ByteArrayInputStream(body);
            return new ServletInputStream() {
                @Override
                public boolean isFinished() {
                    return in.available() == 0;
                }

                @Override
                public boolean isReady() {
                    // Data is fully in memory, so a read never blocks.
                    return true;
                }

                @Override
                public void setReadListener(ReadListener readListener) {
                    // Non-blocking (async) reads are not supported by this wrapper.
                }

                @Override
                public int read() {
                    return in.read();
                }

                @Override
                public int read(byte[] b, int off, int len) {
                    return in.read(b, off, len);
                }
            };
        }
    }
}
过滤器用到的工具类:HttpContextUtils
import javax.servlet.ServletRequest;
import javax.servlet.http.HttpServletRequest;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.Map;
/**
 * HttpContextUtils
 * <p>
 * Static helpers for reading request parameters and the request body.
 */
public class HttpContextUtils {
    /**
     * Collects every request parameter into a map. For multi-valued parameters, only the value
     * returned by {@code getParameter} is kept.
     *
     * @param request current request
     * @return parameter name to value
     */
    public static Map<String, String> getParameterMapAll(HttpServletRequest request) {
        Map<String, String> result = new HashMap<>();
        for (Enumeration<String> names = request.getParameterNames(); names.hasMoreElements(); ) {
            String name = names.nextElement();
            result.put(name, request.getParameter(name));
        }
        return result;
    }

    /**
     * Reads the request body as UTF-8, concatenating its lines without line terminators, and
     * closes the stream. On an I/O error the stack trace is printed and whatever was read so far
     * is returned.
     *
     * @param request current request
     * @return body text
     */
    public static String getBodyString(ServletRequest request) {
        StringBuilder body = new StringBuilder();
        try (InputStream in = request.getInputStream();
             BufferedReader lines = new BufferedReader(new InputStreamReader(in, StandardCharsets.UTF_8))) {
            for (String line = lines.readLine(); line != null; line = lines.readLine()) {
                body.append(line);
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
        return body.toString();
    }
}