防止SQL注入：在SQL中我们通常使用 `#{}` 预编译占位符，但有些时候难免要拼接SQL（如 MyBatis 的 `${}`），这时就需要在后端对参数进行校验。下面综合网上查到的资料，把实现方式记录一下。注意：基于关键字的正则过滤只是兜底手段，既可能误判也可能被绕过，能用预编译参数的地方仍应优先使用 `#{}`。
第一步：注册拦截器
import org.springframework.web.servlet.config.annotation.InterceptorRegistry;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.servlet.config.annotation.WebMvcConfigurer;
@Configuration
public class MyWedAppConfigurer implements WebMvcConfigurer {

    /**
     * Puts the parameter-checking {@link MyInterceptor} on Spring MVC's interceptor chain for
     * every request path.
     *
     * <p>Several interceptors registered here form a chain; {@code addPathPatterns} selects the
     * paths an interceptor applies to and {@code excludePathPatterns} can be used to skip paths.
     *
     * @param registry the registry the interceptor is added to
     */
    @Override
    public void addInterceptors(InterceptorRegistry registry) {
        MyInterceptor sqlGuard = new MyInterceptor();
        registry.addInterceptor(sqlGuard)
                .addPathPatterns("/**");
        // Also delegate to the interface's default implementation.
        WebMvcConfigurer.super.addInterceptors(registry);
    }
}
第二步：实现自己的拦截器，对请求参数进行校验
import com.alibaba.fastjson.JSONObject;
import org.apache.log4j.Logger;
import org.springframework.beans.factory.BeanFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Component;
import org.springframework.web.context.support.WebApplicationContextUtils;
import org.springframework.web.servlet.HandlerInterceptor;
import org.springframework.web.servlet.ModelAndView;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.PrintWriter;
import java.net.InetAddress;
import java.net.UnknownHostException;
import java.util.*;
import java.util.regex.Pattern;
@Component
public class MyInterceptor implements HandlerInterceptor {
private Pattern sqlPattern = Pattern.compile(
"\\b(and|exec|insert|drop|grant|alter|delete|update|count|chr|mid|master|truncate|char|declare|or|exec|having|sleep)\\b|^\\.{2}/|\\s\\|{2}\\s|\\s\\+\\s");// |('|%)S
@Override
public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler)
throws Exception {
boolean fang = true;
// String uri = request.getRequestURI();
//获取访问者的IP
String remoteAddr = request.getRemoteAddr();
remoteAddr = getIpAddsr(request);
if (request.getMethod().equalsIgnoreCase("post")) {
fang = checkPost(request);
}
if (request.getMethod().equalsIgnoreCase("get")) {
fang = checkGet(request);
}
if (!fang) {
code = "400";
message = "ParameterIllegal !";
}
}
if (!fang) {
PrintWriter writer = response.getWriter();
writer.write("{\n" +
" \"code\": " + code + ",\n" +
" \"msg\": " + message + ",\n" +
" \"data\": null\n" +
"}");
writer.flush();
writer.close();
}
return fang;
}
private boolean checkGet(HttpServletRequest request) {
return !request.getParameterMap().values().stream().anyMatch(values -> {
for (String v : values) {
if (sqlPattern.matcher(v).find()) {
LOG.error(v + " Illegal ");
return true;
}
}
return false;
});
}
private boolean checkPost(HttpServletRequest request) {
try {
Map reqMap = new HashMap<String, Object>();
String paramString = "";
MyRequestWrapper requestWrapper = new MyRequestWrapper(request);
String bodyString = requestWrapper.getBodyString();
if (StringUtil.isNotEmpty(bodyString)) {
Map postMap = JSONObject.parseObject(bodyString, Map.class);
paramString = JSONObject.toJSONString(reqMap);
reqMap.putAll(postMap);
}
boolean containKey = sqlPattern.matcher(paramString).find();
if (containKey) {
LOG.error(paramString + " Illegal ");
}
return !containKey;
} catch (Exception ex) {
return true;
}
}
@Override
public void postHandle(HttpServletRequest request, HttpServletResponse response, Object handler, ModelAndView modelAndView) throws Exception {
}
@Override
public void afterCompletion(HttpServletRequest request, HttpServletResponse response, Object handler, Exception ex) throws Exception {
}
}
第三步：包装请求并缓存请求体，使其可以被重复读取并继续传递下去（HttpServletRequest 的输入流只能读取一次）
import javax.servlet.ReadListener;
import javax.servlet.ServletInputStream;
import javax.servlet.ServletRequest;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import java.io.*;
import java.nio.charset.Charset;
public class MyRequestWrapper extends HttpServletRequestWrapper {

    /** Charset used to decode the buffered body for {@link #getBodyString()} and readers. */
    private static final Charset UTF_8 = Charset.forName("UTF-8");

    /** The complete request body, read once in the constructor and replayed on every access. */
    private final byte[] body;

    /**
     * Constructs a request object wrapping the given request and buffers its body in memory.
     *
     * <p>The raw bytes are stored unchanged: there is no re-encoding and no stripping of line
     * breaks. If reading fails part-way, the bytes read before the failure are kept.
     *
     * @param request The request to wrap
     * @throws IllegalArgumentException if the request is null
     */
    public MyRequestWrapper(HttpServletRequest request) {
        super(request);
        ByteArrayOutputStream buffered = new ByteArrayOutputStream();
        try {
            transfer(request.getInputStream(), buffered);
        } catch (IOException e) {
            // Best effort: continue with whatever part of the body was read.
            // TODO(review): route this through the application's logger instead of stderr.
            e.printStackTrace();
        }
        body = buffered.toByteArray();
    }

    /**
     * Returns the buffered request body decoded as UTF-8.
     *
     * @return the body text, empty if the request had no body
     */
    public String getBodyString() {
        return new String(body, UTF_8);
    }

    /**
     * Copies the remaining content of a stream into memory.
     *
     * @param inputStream the stream to drain
     * @return a new stream over the copied bytes; on an {@link IOException} it contains the
     *         bytes read before the failure
     */
    public InputStream cloneInputStream(ServletInputStream inputStream) {
        ByteArrayOutputStream buffered = new ByteArrayOutputStream();
        try {
            transfer(inputStream, buffered);
        } catch (IOException e) {
            e.printStackTrace();
        }
        return new ByteArrayInputStream(buffered.toByteArray());
    }

    /** Copies {@code in} to {@code out} until end of stream. */
    private static void transfer(InputStream in, OutputStream out) throws IOException {
        byte[] buffer = new byte[1024];
        int len;
        while ((len = in.read(buffer)) != -1) {
            out.write(buffer, 0, len);
        }
    }

    /** Returns a reader over the buffered body, decoded as UTF-8 like {@link #getBodyString()}. */
    @Override
    public BufferedReader getReader() throws IOException {
        return new BufferedReader(new InputStreamReader(getInputStream(), UTF_8));
    }

    /** Returns a fresh stream over the buffered body; may be called any number of times. */
    @Override
    public ServletInputStream getInputStream() throws IOException {
        final ByteArrayInputStream source = new ByteArrayInputStream(body);
        return new ServletInputStream() {
            @Override
            public int read() throws IOException {
                return source.read();
            }

            @Override
            public int read(byte[] b, int off, int len) throws IOException {
                return source.read(b, off, len);
            }

            @Override
            public boolean isFinished() {
                return source.available() == 0;
            }

            @Override
            public boolean isReady() {
                // The data is already in memory, so reads never block.
                return true;
            }

            @Override
            public void setReadListener(ReadListener readListener) {
                // Non-blocking reads are not supported by this in-memory stream.
            }
        };
    }
}
第四步：编写过滤器，对请求进行包装处理
import javax.servlet.*;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.util.Arrays;
import java.util.List;
public class RequestFilter implements Filter {

    @Override
    public void init(FilterConfig filterConfig) throws ServletException {
    }

    /**
     * Adds CORS headers, answers preflight ({@code OPTIONS}) requests directly, and wraps POST
     * requests in a {@link MyRequestWrapper} so their body can be read more than once.
     */
    @Override
    public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain) throws IOException, ServletException {
        HttpServletRequest req = (HttpServletRequest) servletRequest;
        HttpServletResponse res = (HttpServletResponse) servletResponse;
        addCorsHeaders(req, res);
        if ("OPTIONS".equals(req.getMethod())) {
            res.setStatus(HttpServletResponse.SC_OK);
            return;
        }
        // A servlet input stream can only be consumed once, so POST bodies are buffered. Form and
        // multipart bodies are left alone: once their stream has been read, the container no
        // longer parses them into parameters/parts and the controller would see them as empty.
        if ("post".equalsIgnoreCase(req.getMethod()) && !isFormOrMultipart(req)) {
            filterChain.doFilter(new MyRequestWrapper(req), servletResponse);
        } else {
            filterChain.doFilter(servletRequest, servletResponse);
        }
    }

    /** Adds the CORS response headers for {@code req}. */
    private static void addCorsHeaders(HttpServletRequest req, HttpServletResponse res) {
        String origin = req.getHeader("Origin");
        if (origin != null && !origin.isEmpty()) {
            // NOTE(review): reflecting any Origin together with Allow-Credentials=true lets every
            // website send credentialed requests on behalf of logged-in users; restrict this to
            // an allowlist of trusted origins.
            res.addHeader("Access-Control-Allow-Origin", origin);
            // The response varies per Origin, so caches must not share it across origins.
            res.addHeader("Vary", "Origin");
        }
        // On credentialed requests browsers treat "*" literally rather than as a wildcard, so
        // echo what the preflight asked for and fall back to "*" otherwise.
        res.addHeader("Access-Control-Allow-Methods", headerOrWildcard(req, "Access-Control-Request-Method"));
        res.addHeader("Access-Control-Allow-Credentials", "true");
        res.addHeader("Access-Control-Allow-Headers", headerOrWildcard(req, "Access-Control-Request-Headers"));
    }

    /** Returns the value of header {@code name}, or {@code "*"} when it is absent or empty. */
    private static String headerOrWildcard(HttpServletRequest req, String name) {
        String value = req.getHeader(name);
        return (value == null || value.isEmpty()) ? "*" : value;
    }

    /** Returns whether the request body is form-urlencoded or multipart. */
    private static boolean isFormOrMultipart(HttpServletRequest req) {
        String contentType = req.getContentType();
        return contentType != null
                && (startsWithIgnoreCase(contentType, "application/x-www-form-urlencoded")
                        || startsWithIgnoreCase(contentType, "multipart/"));
    }

    private static boolean startsWithIgnoreCase(String text, String prefix) {
        return text.regionMatches(true, 0, prefix, 0, prefix.length());
    }

    /**
     * Returns whether {@code curUri} contains one of the listed URI fragments.
     *
     * <p>Currently unused; the list holds only a placeholder entry (which reads "the endpoints to
     * check go here") that is meant to be replaced with real paths.
     *
     * @param curUri the request URI to test
     * @return {@code true} if any listed fragment occurs in {@code curUri}
     */
    private boolean containRegisterUri(String curUri) {
        List<String> urls = Arrays.asList(
                "这里是需要判断的接口"
        );
        for (String url : urls) {
            if (curUri.contains(url)) {
                return true;
            }
        }
        return false;
    }

    @Override
    public void destroy() {
    }
}
第五步：在启动类中注册过滤器
/**
 * Registers {@link RequestFilter} for every URL under the name "efmRequestFilter" with order 1.
 *
 * <p>The "paramName"/"paramValue" init parameter is a placeholder; RequestFilter's
 * {@code init} does not read its {@code FilterConfig}.
 *
 * @return the filter registration
 */
@Bean
public FilterRegistrationBean httpServletRequestReplacedRegistration() {
    FilterRegistrationBean bean = new FilterRegistrationBean(new RequestFilter());
    bean.setName("efmRequestFilter");
    bean.setOrder(1);
    bean.addUrlPatterns("/*");
    bean.addInitParameter("paramName", "paramValue");
    return bean;
}