package rate_limit
import (
"fmt"
"time"
"github.com/go-redis/redis"
)
// RateLimitParam describes one fixed-window limit: requests beyond Threhold
// within a single Period are rejected.
type RateLimitParam struct {
	// Threhold is the maximum counter value allowed within one Period. Note
	// that the first request of a window is always allowed, even when
	// Threhold is 0. (The field name is misspelled; it is kept for API
	// compatibility.)
	Threhold int64
	// Period is the window length. It is used as the TTL of the backing
	// Redis counter and is also embedded (in nanoseconds) in the counter key.
	Period time.Duration
}
// client is the shared Redis client used by the rate-limit checks. It is
// assigned by Init; calls made before Init would operate on a nil client.
var client *redis.Client
// RateLimitCheck records one request for key against every limit in limits
// and reports whether the request is allowed by all of them.
//
// Each limit uses its own Redis counter, keyed as
// "rate_limit:<key>:<Period in nanoseconds>". Limits are evaluated in order.
// Evaluation stops at the first exceeded limit and returns ok == false.
// Later limits are then not counted.
//
// Redis failures fail open: a limit whose check errors is treated as passing,
// and the remaining limits are still evaluated. The first such error is
// returned in err, so callers can log or alert on it instead of losing it.
func RateLimitCheck(key string, limits ...RateLimitParam) (ok bool, err error) {
	var firstErr error
	for _, limit := range limits {
		realKey := fmt.Sprintf("rate_limit:%s:%d", key, limit.Period)
		allowed, checkErr := rateLimitCheck(realKey, limit)
		if checkErr != nil && firstErr == nil {
			// Keep the first failure; later iterations must not overwrite it.
			firstErr = checkErr
		}
		if !allowed {
			return false, firstErr
		}
	}
	return true, firstErr
}
// rateLimitCheck records one hit on the fixed-window counter stored at realKey
// and reports whether that hit is within limit.Threhold.
//
// The first hit of a window creates the counter with SETNX (value 1, TTL
// limit.Period) and is always allowed. Later hits INCR the counter and are
// rejected once it exceeds limit.Threhold.
//
// Redis errors fail open: the hit is allowed (ok == true) and the error is
// returned so the caller can report it.
func rateLimitCheck(realKey string, limit RateLimitParam) (ok bool, err error) {
	created, err := client.SetNX(realKey, 1, limit.Period).Result()
	if err != nil {
		return true, err // bypass the limit if Redis is unavailable
	}
	if created { // first hit of a new window
		return true, nil
	}

	count, err := client.Incr(realKey).Result()
	if err != nil {
		// Report the INCR error itself; the SETNX error is nil at this point.
		return true, err
	}
	if count == 1 {
		// The key expired between SETNX and INCR, so INCR recreated it
		// without a TTL. Without this EXPIRE the counter would never reset
		// and the key would eventually be rejected forever.
		if expErr := client.Expire(realKey, limit.Period).Err(); expErr != nil {
			return true, expErr
		}
	}
	return count <= limit.Threhold, nil
}
// Init creates the package-level Redis client for the server at addr
// (database 0, no password, pool of 10 connections) and verifies the
// connection with a PING. It panics with the PING error if the server
// cannot be reached.
func Init(addr string) {
	opts := &redis.Options{
		Network:      "tcp",
		Addr:         addr,
		Password:     "", // no password set
		DB:           0,  // use default DB
		PoolSize:     10,
		MinIdleConns: 5,
		DialTimeout:  100 * time.Millisecond,
		ReadTimeout:  500 * time.Millisecond,
		WriteTimeout: 500 * time.Millisecond,
		IdleTimeout:  10 * time.Second,
	}
	client = redis.NewClient(opts)

	pingErr := client.Ping().Err()
	if pingErr == nil {
		return
	}
	panic(pingErr)
}