# 简化版:不使用 Redis,不区分 IP 和接口
# -*- coding: utf-8 -*-
import time
from flask import Flask
# Flask application object that the route below is registered on.
app = Flask(__name__)
capacity = 10 # bucket capacity: upper bound applied to current_tokens
rate = 1 # refill rate: tokens added per elapsed second
last_time = int(time.time()) # whole-second timestamp the next refill is measured from
current_tokens = capacity # tokens currently in the bucket; starts full
def can_access():
    """Try to take one token from the process-wide token bucket.

    The bucket state is kept in the module-level ``current_tokens`` and
    ``last_time`` variables and is refilled lazily on each call:
    ``rate`` tokens per elapsed whole second, capped at ``capacity``.

    NOTE: the globals are updated without any locking.

    :return: ``True`` if a token was available and has been consumed,
        ``False`` if the bucket is empty.
    """
    global current_tokens
    global last_time
    # Read the clock exactly once so the refill amount and the stored
    # timestamp always describe the same instant.
    now = int(time.time())
    # Clamp at zero: if the wall clock steps backwards the difference is
    # negative and would otherwise remove tokens from the bucket.
    elapsed = max(0, now - last_time)
    current_tokens = min(capacity, current_tokens + elapsed * rate)
    # The refill up to ``now`` has been accounted for, so always advance the
    # reference point (this also recovers from a timestamp left in the
    # future by a clock adjustment).
    last_time = now
    if current_tokens > 0:
        current_tokens -= 1
        return True
    return False
@app.route('/')
def tokens_bucket():
    """Serve the greeting, or reject the request when no token is available.

    :return: the greeting text on success; on throttling, a
        ``(message, 429)`` tuple so the response carries
        HTTP 429 Too Many Requests instead of 200 OK.
    """
    if not can_access():
        # A throttled request must not look like a success to clients.
        return '速率超限制', 429
    return 'Hello, 令牌桶!'
# Script entry point: start the app with its default run settings.
if __name__ == "__main__":
    app.run()
# Redis 版:利用 Redis,区分 IP 和接口
# -*- coding: utf-8 -*-
import time
from flask import Flask
from flask import request
from redis import Redis
# Redis client constructed with no arguments (library default connection).
redis_client = Redis()
# Flask application object that the route below is registered on.
app = Flask(__name__)
capacity = 5 # bucket capacity: upper bound applied to each bucket's token count
rate = 1 # refill rate: tokens added per elapsed second
def can_access(ip, func):
    """Decide whether ``ip`` may call ``func`` now, using a token bucket in Redis.

    Each (ip, function name) pair has its own Redis hash with the fields
    ``current_tokens`` and ``last_time``. The bucket is refilled lazily:
    ``rate`` tokens per elapsed whole second, capped at ``capacity``. A
    missing hash is treated as a full bucket.

    NOTE: the read and the write are separate Redis operations, so two
    concurrent requests for the same key can both observe the same token
    count.

    :param ip: client address; first part of the Redis key.
    :param func: the protected function; its ``__name__`` is the second
        part of the Redis key.
    :return: ``True`` if a token was available and has been consumed,
        ``False`` if the bucket is empty (nothing is written in that case).
    """
    # The key is the ip and the function name concatenated.
    redis_key = ip + func.__name__
    # Read the clock once so the refill and the stored timestamp agree.
    now = int(time.time())
    # Fetch both fields in a single round trip.
    stored_tokens, stored_time = redis_client.hmget(
        redis_key, 'current_tokens', 'last_time')
    current_tokens = int(stored_tokens) if stored_tokens else capacity
    last_time = int(stored_time) if stored_time else now
    # Clamp at zero so a backwards clock step cannot remove tokens.
    elapsed = max(0, now - last_time)
    current_tokens = min(capacity, current_tokens + elapsed * rate)
    if current_tokens <= 0:
        return False
    # pipeline() is transactional by default, so the writes below are
    # applied together in one round trip.
    pipe = redis_client.pipeline()
    pipe.hset(redis_key, 'current_tokens', current_tokens - 1)
    pipe.hset(redis_key, 'last_time', now)
    if rate > 0:
        # After capacity / rate seconds the bucket would be full again, which
        # is exactly what a missing key means, so the key can safely expire
        # instead of accumulating forever for every client ever seen.
        pipe.expire(redis_key, int(capacity / rate) + 1)
    pipe.execute()
    return True
@app.route('/')
def tokens_bucket():
    """Serve the greeting, rate-limited per client address for this endpoint.

    :return: the greeting text on success; on throttling, a
        ``(message, 429)`` tuple so the response carries
        HTTP 429 Too Many Requests instead of 200 OK.
    """
    ip = request.remote_addr
    if not can_access(ip, tokens_bucket):
        # A throttled request must not look like a success to clients.
        message = '当前ip:{}访问:{}接口速率超限制'.format(ip, tokens_bucket.__name__)
        return message, 429
    return 'Hello, 令牌桶!'
# Script entry point: start the app with its default run settings.
if __name__ == "__main__":
    app.run()
# 更通用的装饰器版:
# -*- coding: utf-8 -*-
import time
from flask import Flask
from flask import request
from redis import Redis
# Redis client constructed with no arguments (library default connection).
redis_client = Redis()
# Flask application object that the route below is registered on.
app = Flask(__name__)
# Field names used inside each per-client Redis hash.
current_tokens_key = 'current_tokens'
last_time_key = 'last_time'
def can_access(rate=1, capacity=5):
    """Build a decorator that rate-limits a Flask view with a token bucket.

    Each (client address, view name) pair has its own Redis hash holding the
    remaining tokens and the time of the last granted request. The bucket is
    refilled lazily: ``rate`` tokens per elapsed whole second, capped at
    ``capacity``. A missing hash is treated as a full bucket.

    NOTE: the read and the write are separate Redis operations, so two
    concurrent requests for the same key can both observe the same token
    count.

    :param rate: tokens added to the bucket per second, default 1
    :param capacity: maximum number of tokens in the bucket, default 5
    :return: a decorator; the decorated view returns the original view's
        result when a token is available, otherwise a ``(message, 429)``
        tuple (HTTP 429 Too Many Requests).
    """
    from functools import wraps

    # After capacity / rate seconds an untouched bucket is full again, which
    # is what a missing key means, so keys may expire after that long instead
    # of accumulating forever. With no refill the state must be kept.
    ttl = int(capacity / rate) + 1 if rate > 0 else None

    def wrapper(func):
        func_name = func.__name__

        # Preserve the view's name and docstring: Flask derives the endpoint
        # name from ``__name__``, so without this every decorated view would
        # be registered as "inner" and a second one would collide.
        @wraps(func)
        def inner(*arg, **kwargs):
            ip = request.remote_addr
            hash_name = ip + func_name
            # Read the clock once so refill and stored timestamp agree.
            now = int(time.time())
            # Fetch both fields in a single round trip.
            stored_tokens, stored_time = redis_client.hmget(
                hash_name, current_tokens_key, last_time_key)
            current_tokens = int(stored_tokens) if stored_tokens else capacity
            last_time = int(stored_time) if stored_time else now
            # Clamp at zero so a backwards clock step cannot remove tokens.
            elapsed = max(0, now - last_time)
            current_tokens = min(capacity, current_tokens + elapsed * rate)
            if current_tokens <= 0:
                message = '当前ip:{}访问:{}接口速率超限制'.format(ip, func_name)
                return message, 429
            # pipeline() is transactional by default, so the writes below
            # are applied together in one round trip.
            pipe = redis_client.pipeline()
            pipe.hset(hash_name, current_tokens_key, current_tokens - 1)
            pipe.hset(hash_name, last_time_key, now)
            if ttl is not None:
                pipe.expire(hash_name, ttl)
            pipe.execute()
            return func(*arg, **kwargs)
        return inner
    return wrapper
@app.route('/')
@can_access(rate=1, capacity=10)
def tokens_bucket():
    """Return the greeting text; wrapped by ``can_access(rate=1, capacity=10)``."""
    greeting = 'Hello, 令牌桶!'
    return greeting
# Script entry point: start the app with its default run settings.
if __name__ == "__main__":
    app.run()