Go Gin 中间件开发
本文实现了两个 Gin 中间件:一是对请求参数(包括 GET 查询参数、POST 表单参数和 JSON 参数)做 SQL 注入检查,二是 token 校验。
思路如下:
- 在中间件中拿到 request,取出需要检查或过滤的参数。
- token 检查:数据库中有一张 uid -> token 的映射表,用请求头中的 token 查询该表,判断 token 是否存在。
- SQL 注入检查:用正则表达式逐一匹配每个参数。注意 JSON 参数与 GET/POST 表单参数的获取方式不同,具体见代码。
- 另外维护一个 slice 和一个 map 作为白名单:请求路径命中 slice 中的 URL 时跳过 token 检查,参数名命中 map 中的键时跳过 SQL 注入检查。
Token 检查
// Path fragments consulted by CheckToken. A request matches an entry when
// strings.Contains(c.Request.URL.Path, entry) is true.
var (
	// skipUrlForTokenArr: matching requests bypass the token check entirely.
	skipUrlForTokenArr = []string{
		"/test",
		"/public/",
		"/user/login",
		"/user/register",
		"/user/token/refresh",
		"/user/wxopendata/decode",
		"/user/wxserver/connect",
	}

	// maybeCheckForTokenArr: matching requests may omit the "token" header;
	// when a token is sent anyway it is still validated.
	maybeCheckForTokenArr = []string{"/user/homepage"}
)
// skipParamsForSQLInjectMap lists parameter names that CheckSQLInject does
// not pass to tools.FilteredSQLInject. Lookups compare the value against 1,
// so an entry must map to 1 to take effect.
//
// NOTE(review): presumably these fields carry uploads, WeChat encrypted
// payloads or free-form text/links that would trip the SQL-injection filter —
// confirm before adding more names, since skipped fields are never filtered.
var skipParamsForSQLInjectMap = map[string]int{
	"file":           1,
	"encrypt_data":   1,
	"openid":         1,
	"iv":             1,
	"name":           1,
	"link":           1,
	"head_pic_link":  1,
	"pic_link":       1,
	"id_front_pic":   1,
	"id_reverse_pic": 1,
	"education_pic":  1,
}
// User holds the columns CheckToken reads from the `user` table (level,
// phone, nickname, sex) plus the uid resolved from the token. CheckToken
// stores it by value in the gin context under the "user" key.
type User struct {
	Uid, Level      int
	Phone, Nickname string
	Sex             int
}
func CheckToken() gin.HandlerFunc {
return func(c *gin.Context) {
for _, url := range skipUrlForTokenArr {
if strings.Contains(c.Request.URL.Path, url) {
c.Next()
return
}
}
token := c.GetHeader("token")
if token == "" {
for _, url := range maybeCheckForTokenArr {
if strings.Contains(c.Request.URL.Path, url) {
c.Next()
return
}
}
err := errors.New("invalid token")
tools.Mlog.Error(err.Error())
c.JSON(http.StatusBadRequest, code.GeneralErrRet(err))
c.Abort()
return
}
fmt.Printf("token = %s\n", token)
if tools.FilteredSQLInject(token) {
err := fmt.Errorf("sql注入攻击 %s", token)
tools.Mlog.Error(err.Error())
c.JSON(http.StatusBadRequest, code.GeneralErrRet(err))
c.Abort()
return
}
uid := -1
var tokenCreateAt time.Time
row := db.DbApp.QueryRow(`SELECT uid, token_create_at FROM user_token WHERE token = ?;`, token)
err := row.Scan(&uid, &tokenCreateAt)
if err != nil && err != sql.ErrNoRows {
tools.Mlog.Error(err.Error())
c.JSON(http.StatusInternalServerError, code.GeneralErrRet(err))
c.Abort()
}
if uid == -1 {
c.JSON(http.StatusOK, code.RetTokenIsntExists)
c.Abort()
return
}
var u User
row = db.DbApp.QueryRow(`SELECT level, phone, nickname, sex FROM user WHERE uid = ?;`, uid)
u.Uid = uid
err = row.Scan(&u.Level, &u.Phone, &u.Nickname, &u.Sex)
if err != nil {
tools.Mlog.Error(err.Error())
c.JSON(http.StatusInternalServerError, code.GeneralErrRet(err))
c.Abort()
return
}
c.Set("uid", u.Uid)
c.Set("user", u)
c.Next()
}
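SQL 注入检查
- 这里通过 httputil.DumpRequest 获取请求内容,而不直接读取原始 request 的 body:body 是只能读取一次的流,在中间件里读完后,后续 handler 就读不到了。DumpRequest 的第二个参数为 true 时会读取 body,并把 request 的 body 替换成一份内容相同的新副本,因此后续 handler 仍能正常读取。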
// CheckSQLInject returns a gin middleware that rejects requests whose
// parameters are flagged by tools.FilteredSQLInject.
//
// Two sources of parameters are inspected:
//   - the body of requests whose Content-Type is application/json. Every
//     string value is checked, recursing into nested objects and arrays.
//   - c.Request.Form, i.e. query-string values plus urlencoded and
//     multipart form fields.
//
// Parameters named by a key of skipParamsForSQLInjectMap are not checked.
// A flagged value aborts with 400. A JSON body that cannot be read or decoded
// aborts with 500. The body is restored on c.Request, so handlers further
// down the chain can still bind it. Parameter values are never printed,
// since they may contain passwords or other secrets.
//
// NOTE(review): JSON detection relies on Content-Type only. A JSON body sent
// with another Content-Type is not decoded here. Confirm that handlers do not
// bind JSON regardless of Content-Type.
func CheckSQLInject() gin.HandlerFunc {
	return func(c *gin.Context) {
		if c.ContentType() == "application/json" {
			body, err := jsonBodyForSQLInjectCheck(c)
			if err != nil {
				err = fmt.Errorf("parse request body failed, err: %w", err)
				tools.Mlog.Error(err.Error())
				c.JSON(http.StatusInternalServerError, code.GeneralErrRet(err))
				c.Abort()
				return
			}
			if len(body) != 0 {
				var payload interface{}
				if err := json.Unmarshal(body, &payload); err != nil {
					err = fmt.Errorf("parse request body failed, err: %w", err)
					tools.Mlog.Error(err.Error())
					c.JSON(http.StatusInternalServerError, code.GeneralErrRet(err))
					c.Abort()
					return
				}
				if bad, found := findSQLInjectInJSON(payload); found {
					err := fmt.Errorf("sql注入攻击 %s", bad)
					tools.Mlog.Error(err.Error())
					c.JSON(http.StatusBadRequest, code.GeneralErrRet(err))
					c.Abort()
					return
				}
			}
		}

		if c.Request.Form == nil {
			// ParseMultipartForm calls ParseForm first, so query-string and
			// urlencoded values are populated even when it returns
			// http.ErrNotMultipart for non-multipart bodies. The error is
			// ignored on purpose: this check is best effort over whatever
			// could be parsed.
			_ = c.Request.ParseMultipartForm(32 << 20)
		}
		for key, values := range c.Request.Form {
			if skipParamsForSQLInjectMap[key] == 1 {
				continue
			}
			for _, v := range values {
				if tools.FilteredSQLInject(v) {
					err := errors.New("sql注入攻击")
					tools.Mlog.Error(err.Error())
					c.JSON(http.StatusBadRequest, code.GeneralErrRet(err))
					c.Abort()
					return
				}
			}
		}
		c.Next()
	}
}

// jsonBodyForSQLInjectCheck returns the request body with surrounding
// whitespace trimmed, leaving an identical copy installed on c.Request.Body
// for later handlers.
func jsonBodyForSQLInjectCheck(c *gin.Context) ([]byte, error) {
	// DumpRequest consumes the body and replaces it with a fresh reader over
	// the same bytes. It writes chunked bodies in chunked framing, so dump a
	// shallow copy with TransferEncoding cleared to get the raw payload, then
	// hand the restored body back to the real request.
	req := *c.Request
	req.TransferEncoding = nil
	dump, err := httputil.DumpRequest(&req, true)
	if err != nil {
		return nil, err
	}
	c.Request.Body = req.Body

	// The header section ends at the first empty line; the body follows it.
	// Searching for a header name or the first '{' would break when that
	// header is absent or a header value contains '{'.
	sep := []byte("\r\n\r\n")
	i := bytes.Index(dump, sep)
	if i < 0 {
		return nil, nil
	}
	return bytes.TrimSpace(dump[i+len(sep):]), nil
}

// findSQLInjectInJSON walks a value decoded by encoding/json into an
// interface{} and returns a string value flagged by tools.FilteredSQLInject.
// Object members whose key is in skipParamsForSQLInjectMap are skipped at
// every depth. Numbers, booleans and null are never flagged; a type switch
// is used so that null cannot cause a nil-type panic.
func findSQLInjectInJSON(v interface{}) (string, bool) {
	switch val := v.(type) {
	case string:
		if tools.FilteredSQLInject(val) {
			return val, true
		}
	case map[string]interface{}:
		for key, child := range val {
			if skipParamsForSQLInjectMap[key] == 1 {
				continue
			}
			if bad, found := findSQLInjectInJSON(child); found {
				return bad, true
			}
		}
	case []interface{}:
		for _, child := range val {
			if bad, found := findSQLInjectInJSON(child); found {
				return bad, true
			}
		}
	}
	return "", false
}
使用方式
// Register both middlewares globally. Handlers added with Use run in the
// order they are added, so the token check runs before the SQL-injection
// check.
router.Use(middleware.CheckToken())
router.Use(middleware.CheckSQLInject())