Go gin 中间件开发

我开发了两个 gin 中间件:一个对 request 中的参数(包括 GET、POST 表单参数和 JSON 参数)做 SQL 注入检查,另一个做 token 检查。

  • 先说一下思路:
    1. 在中间件中拿到 request,取出需要检查或过滤的参数。
    2. token 检查:数据库中有一张 uid -> token 的映射表,用请求中携带的 token 去查询,检查它是否存在。
    3. SQL 注入检查:用正则表达式逐个匹配参数。注意 JSON 参数与 GET/POST 表单参数的获取方式不同,具体见代码。正则过滤只是辅助手段,防止注入的根本办法是使用参数化查询(例如下文查询 token 时使用的 ? 占位符)。
    4. 准备一个 slice 和一个 map 作为白名单,当请求的 URL 或参数名命中其中的条目时,跳过相应的检查。

token检查

  • 跳过检查的 URL 和参数(白名单)
// Path fragments consulted by CheckToken, which matches each entry against
// the request path with strings.Contains.
var (
	// skipUrlForTokenArr: requests whose path contains one of these
	// fragments are never token-checked.
	skipUrlForTokenArr = []string{
		"/test",
		"/public/",
		"/user/login",
		"/user/register",
		"/user/token/refresh",
		"/user/wxopendata/decode",
		"/user/wxserver/connect",
	}

	// maybeCheckForTokenArr: for these paths a request without a "token"
	// header is let through anonymously; a request that does send a token
	// is still validated.
	maybeCheckForTokenArr = []string{
		"/user/homepage",
	}
)

// skipParamsForSQLInjectMap lists parameter names whose values CheckSQLInject
// does not pass to the SQL-injection filter. A value of 1 marks a name as
// skipped; names that are absent look up as 0 and are checked.
//
// NOTE(review): the published snippet stopped before the closing brace, so
// further entries may have been lost — compare with the real source.
var skipParamsForSQLInjectMap = map[string]int{
	"file":           1,
	"encrypt_data":   1,
	"openid":         1,
	"iv":             1,
	"name":           1,
	"link":           1,
	"head_pic_link":  1,
	"pic_link":       1,
	"id_front_pic":   1,
	"id_reverse_pic": 1,
	"education_pic":  1,
}

  • 检查 token,并获取 token 对应的用户信息

// User is the profile CheckToken loads for an authenticated request; it is
// stored in the gin context under the key "user" (and Uid also under "uid").
type User struct {
	Uid      int    // uid resolved from the user_token table
	Level    int    // user.level column
	Phone    string // user.phone column
	Nickname string // user.nickname column
	Sex      int    // user.sex column; NOTE(review): value encoding not visible here — confirm
}

func CheckToken() gin.HandlerFunc {
	return func(c *gin.Context) {
		for _, url := range skipUrlForTokenArr {
			if strings.Contains(c.Request.URL.Path, url) {
				c.Next()
				return
			}
		}

		token := c.GetHeader("token")
		if token == "" {
			for _, url := range maybeCheckForTokenArr {
				if strings.Contains(c.Request.URL.Path, url) {
					c.Next()
					return
				}
			}
			err := errors.New("invalid token")
			tools.Mlog.Error(err.Error())
			c.JSON(http.StatusBadRequest, code.GeneralErrRet(err))
			c.Abort()
			return
		}
		fmt.Printf("token = %s\n", token)
		if tools.FilteredSQLInject(token) {
			err := fmt.Errorf("sql注入攻击 %s", token)
			tools.Mlog.Error(err.Error())
			c.JSON(http.StatusBadRequest, code.GeneralErrRet(err))
			c.Abort()
			return
		}
		//检查token是否存在
		uid := -1
		var tokenCreateAt time.Time
		row := db.DbApp.QueryRow(`SELECT uid, token_create_at FROM user_token WHERE token = ?;`, token)
		err := row.Scan(&uid, &tokenCreateAt)
		if err != nil && err != sql.ErrNoRows {
			tools.Mlog.Error(err.Error())
			c.JSON(http.StatusInternalServerError, code.GeneralErrRet(err))
			c.Abort()
		}
		if uid == -1 {
			c.JSON(http.StatusOK, code.RetTokenIsntExists)
			c.Abort()
			return
		}
		//token是否过期
		//m, _ := time.ParseDuration(fmt.Sprintf("%dm", setting.TokenConf.TokenExp))
		//if tokenCreateAt.Add(m).Before(time.Now()) {
		//	err = errors.New("token is expired")
		//	c.JSON(http.StatusUnauthorized, code.GeneralErrRet(err))
		//	c.Abort()
		//	return
		//}

		var u User
		row = db.DbApp.QueryRow(`SELECT level, phone, nickname, sex FROM user WHERE uid = ?;`, uid)
		u.Uid = uid
		err = row.Scan(&u.Level, &u.Phone, &u.Nickname, &u.Sex)
		if err != nil {
			tools.Mlog.Error(err.Error())
			c.JSON(http.StatusInternalServerError, code.GeneralErrRet(err))
			c.Abort()
			return
		}
		c.Set("uid", u.Uid)
		c.Set("user", u)
		c.Next()
	}

检查sql注入

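  • 检查 JSON 参数时需要读取请求体。request body 是只能读取一次的流,读完后必须把内容放回 c.Request.Body,否则后面的 handler 将读不到数据。例如 httputil.DumpRequest(c.Request, true) 的第二个参数 true 表示连同 body 一起转储,并且它会自动把 body 还原。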
// replayableBody re-exposes a request body that CheckSQLInject has already
// read, so downstream handlers can still bind it.
type replayableBody struct{ *bytes.Reader }

// Close is a no-op: the body is an in-memory copy with nothing to release.
func (replayableBody) Close() error { return nil }

// findSQLInjectInJSON walks a value decoded by encoding/json and returns the
// first string that tools.FilteredSQLInject flags. Object members whose key
// is listed in skipParamsForSQLInjectMap are not inspected, at any depth.
// Numbers, booleans and null are never flagged.
func findSQLInjectInJSON(v interface{}) (string, bool) {
	switch val := v.(type) {
	case string:
		return val, tools.FilteredSQLInject(val)
	case map[string]interface{}:
		for k, child := range val {
			if skipParamsForSQLInjectMap[k] == 1 {
				continue
			}
			if s, found := findSQLInjectInJSON(child); found {
				return s, true
			}
		}
	case []interface{}:
		for _, child := range val {
			if s, found := findSQLInjectInJSON(child); found {
				return s, true
			}
		}
	}
	return "", false
}

// CheckSQLInject returns a gin middleware that rejects (HTTP 400) a request
// carrying a parameter value that tools.FilteredSQLInject flags. Parameters
// whose name is in skipParamsForSQLInjectMap are not checked.
//
// Two sources are inspected:
//   - for Content-Type application/json, every string in the body, including
//     strings nested in objects and arrays (malformed JSON is rejected with
//     400);
//   - the URL query and form/multipart fields collected in c.Request.Form.
//
// The JSON body is read fully into memory and then put back on
// c.Request.Body, so later handlers can still bind it.
//
// NOTE(review): gin's BindJSON/ShouldBindJSON decode the body without looking
// at Content-Type, so a JSON body sent as e.g. text/plain skips the JSON check
// (verify for your gin version). The body size is not limited here either.
func CheckSQLInject() gin.HandlerFunc {
	return func(c *gin.Context) {
		// c.ContentType() strips parameters such as "; charset=utf-8".
		if c.ContentType() == "application/json" && c.Request.Body != nil {
			var buf bytes.Buffer
			_, err := buf.ReadFrom(c.Request.Body)
			_ = c.Request.Body.Close()
			if err != nil {
				err = fmt.Errorf("parse request body failed, err: %s", err)
				tools.Mlog.Error(err.Error())
				c.JSON(http.StatusInternalServerError, code.GeneralErrRet(err))
				c.Abort()
				return
			}
			raw := buf.Bytes()
			// The original stream is consumed; hand later handlers a copy.
			c.Request.Body = replayableBody{bytes.NewReader(raw)}

			if len(bytes.TrimSpace(raw)) != 0 {
				var payload interface{}
				if err := json.Unmarshal(raw, &payload); err != nil {
					err = fmt.Errorf("parse request body failed, err: %s", err.Error())
					tools.Mlog.Error(err.Error())
					c.JSON(http.StatusBadRequest, code.GeneralErrRet(err))
					c.Abort()
					return
				}
				if v, found := findSQLInjectInJSON(payload); found {
					err := fmt.Errorf("sql注入攻击 %s", v)
					tools.Mlog.Error(err.Error())
					c.JSON(http.StatusBadRequest, code.GeneralErrRet(err))
					c.Abort()
					return
				}
			}
		}

		// Query-string and form parameters. ParseMultipartForm calls ParseForm
		// first; ErrNotMultipart only means the body is not multipart (GET,
		// JSON, urlencoded), and Form is still populated in that case.
		if c.Request.Form == nil {
			err := c.Request.ParseMultipartForm(32 << 20)
			if err != nil && !errors.Is(err, http.ErrNotMultipart) {
				// Best effort: log and still check whatever was parsed.
				tools.Mlog.Error("parse request form failed, err: " + err.Error())
			}
		}
		for name, values := range c.Request.Form {
			if skipParamsForSQLInjectMap[name] == 1 {
				continue
			}
			for _, v := range values {
				if tools.FilteredSQLInject(v) {
					err := errors.New("sql注入攻击")
					tools.Mlog.Error(err.Error())
					c.JSON(http.StatusBadRequest, code.GeneralErrRet(err))
					c.Abort()
					return
				}
			}
		}
		c.Next()
	}
}

使用

// Register the middlewares in order: token authentication first, then the
// SQL-injection parameter check.
router.Use(middleware.CheckToken())
router.Use(middleware.CheckSQLInject())