package com.qypt.base.handle;
import cn.hutool.core.exceptions.UtilException;
import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONObject;
import org.apache.commons.lang3.StringUtils;
import org.springframework.http.MediaType;
import org.springframework.stereotype.Component;
import org.springframework.util.StreamUtils;
import javax.servlet.*;
import javax.servlet.http.HttpServletRequest;
import java.io.IOException;
import java.io.PrintWriter;
import java.nio.charset.StandardCharsets;
/**
* SQL 注入过滤器
*
* @Description
* @Author GDKY
* @Date 2024/8/21 19:44
**/
@Component
public class SqlInjectionFilter implements Filter {
/**
* 定义常用的 sql关键字
*/
public static String SQL_REGEX = "and |extractvalue|updatexml|exec |insert |select |delete |update |drop |count |chr |mid |master |truncate |char |declare |or |+|user()";
@Override
public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) throws IOException, ServletException {
HttpServletRequest httpRequest = (HttpServletRequest) request;
if (!httpRequest.getMethod().equalsIgnoreCase("POST") ||
!MediaType.APPLICATION_JSON_VALUE.equalsIgnoreCase(httpRequest.getContentType())) {
chain.doFilter(request, response);
return;
}
// 读取请求体
String requestBody = StreamUtils.copyToString(httpRequest.getInputStream(), StandardCharsets.UTF_8);
JSONObject jsonObject = JSON.parseObject(requestBody);
if (jsonObject == null) {
response.setCharacterEncoding("UTF-8");
response.setContentType("application/json; charset=utf-8");
PrintWriter out = response.getWriter();
JSONObject res = new JSONObject();
res.put("msg", "参数不能为空");
res.put("success", "false");
out.append(res.toString());
return;
}
String orderParams = jsonObject.getString("orderParams");
String searchParams = jsonObject.getString("searchParams");
String searchParams2 = jsonObject.getString("searchParams2");
if (StringUtils.isNotBlank(orderParams) || StringUtils.isNotBlank(searchParams) || StringUtils.isNotBlank(searchParams2)) {
try {
this.filterKeyword(orderParams);
this.filterKeyword(searchParams);
this.filterKeyword(searchParams2);
// 创建一个新的请求对象,其中包含了清理后的请求体
HttpServletRequest wrappedRequest = new WrappedHttpServletRequest(httpRequest, requestBody);
chain.doFilter(wrappedRequest, response);
} catch (Exception e) {
response.setCharacterEncoding("UTF-8");
response.setContentType("application/json; charset=utf-8");
PrintWriter out = response.getWriter();
JSONObject res = new JSONObject();
res.put("msg", "参数存在SQL注入风险");
res.put("success", "false");
out.append(res.toString());
}
} else {
// 创建一个新的请求对象,其中包含了清理后的请求体
HttpServletRequest wrappedRequest = new WrappedHttpServletRequest(httpRequest, requestBody);
chain.doFilter(wrappedRequest, response);
}
}
/**
* SQL关键字检查
*/
public void filterKeyword(String value) {
if (StringUtils.isEmpty(value)) {
return;
}
String[] sqlKeywords = StringUtils.split(SQL_REGEX, "\\|");
for (String sqlKeyword : sqlKeywords) {
if (StringUtils.indexOfIgnoreCase(value, sqlKeyword) > -1) {
throw new UtilException("参数存在SQL注入风险");
}
}
}
}
```java
package com.qypt.base.handle;
import javax.servlet.ReadListener;
import javax.servlet.ServletInputStream;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import java.io.ByteArrayInputStream;
import java.io.IOException;
public class WrappedHttpServletRequest extends HttpServletRequestWrapper {
private final String requestBody;
public WrappedHttpServletRequest(HttpServletRequest request, String requestBody) {
super(request);
this.requestBody = requestBody;
}
@Override
public ServletInputStream getInputStream() throws IOException {
final ByteArrayInputStream byteArrayInputStream = new ByteArrayInputStream(requestBody.getBytes());
return new ServletInputStream() {
@Override
public boolean isFinished() {
return false;
}
@Override
public boolean isReady() {
return false;
}
@Override
public void setReadListener(ReadListener readListener) {
}
@Override
public int read() throws IOException {
return byteArrayInputStream.read();
}
};
}
@Override
public String getParameter(String name) {
return super.getParameter(name);
}
}
SqlInjectionFilter 类
这个类实现了 Filter 接口,用于在请求到达控制器之前处理请求体中的参数,以防止 SQL 注入攻击。
主要逻辑
请求体读取:
使用 StreamUtils.copyToString 方法读取请求体中的内容。
JSON 解析:
使用 JSON.parseObject 将请求体内容解析为 JSONObject 对象。
参数检查:
检查 orderParams、searchParams 和 searchParams2 是否存在,并调用 filterKeyword 方法检查这些参数是否存在 SQL 注入风险。
响应处理:
如果参数为空或存在 SQL 注入风险,则返回错误响应。
如果参数安全,则创建一个包装的请求对象 WrappedHttpServletRequest 并继续处理请求。
filterKeyword 方法
这个方法用于检查给定的字符串 value 是否包含 SQL 注入关键字。如果存在,则抛出 UtilException 异常。
WrappedHttpServletRequest 类
这个类继承自 HttpServletRequestWrapper,用于包装原始的 HttpServletRequest 对象,并覆盖 getInputStream 方法,以便在请求体被修改后仍能正确处理。
主要逻辑
输入流重写:
创建一个 ByteArrayInputStream 来读取修改后的请求体。
创建一个新的 ServletInputStream 来包装 ByteArrayInputStream。
其他方法覆盖:
getParameter 方法保持不变,使用父类的实现。