节流的源码流程
dispatch 中通过 self.initial(request, *args, **kwargs) 调用 self.check_throttles(request)
def check_throttles(self, request):
    """
    Check if request should be throttled.

    Iterates over the throttles returned by ``self.get_throttles()`` and,
    for each one whose ``allow_request(request, self)`` is falsy, calls
    ``self.throttled(request, throttle.wait())``.

    Raises whatever ``self.throttled`` raises if the request is throttled.
    """
    for throttle in self.get_throttles():
        if not throttle.allow_request(request, self):
            # Pass the throttle's own wait() value on to the view's hook.
            self.throttled(request, throttle.wait())
#获取节流类
def get_throttles(self):
    """
    Instantiates and returns the list of throttles that this view uses.

    Each entry of ``self.throttle_classes`` is called with no arguments,
    so every throttle class must be constructible without parameters.
    """
    return [throttle() for throttle in self.throttle_classes]
#异常抛出
def throttled(self, request, wait):
    """
    If request is throttled, determine what kind of exception to raise.

    This implementation always raises ``exceptions.Throttled`` built from
    ``wait``; ``request`` is accepted but not used here.
    """
    raise exceptions.Throttled(wait)
#限速的全局配置
'DEFAULT_THROTTLE_CLASSES':[]
SimpleRateThrottle
cache = default_cache
timer = time.time
cache_format = 'throttle_%(scope)s_%(ident)s'
scope = None
THROTTLE_RATES = api_settings.DEFAULT_THROTTLE_RATES
# 访问频率设置的方法
'DEFAULT_THROTTLE_RATES': {
    '<scope>': '<次数>/<周期>',  # 例如 'xxx': '3/m'；周期可为 s、m、h、d
}
使用
1.自定义一个限制类
# Per-IP access log shared by every MyThrottle instance.
# Maps client IP -> list of accepted-request timestamps, newest first.
VISIT_RECORD = {}


# Custom throttle
class MyThrottle(object):
    """Custom throttle: at most three accepted requests per IP per 60 seconds."""

    # Maximum number of accepted requests inside one window.
    MAX_REQUESTS = 3
    # Length of the sliding window, in seconds.
    WINDOW_SECONDS = 60

    def __init__(self):
        # Timestamp list of the IP handled by the latest allow_request()
        # call; wait() reads it.
        self.history = None

    def allow_request(self, request, view):
        """
        Decide whether the request may proceed.

        Looks the client up by ``request.META["REMOTE_ADDR"]``, discards
        timestamps older than the window, and accepts the request only
        while fewer than ``MAX_REQUESTS`` accepted timestamps remain.

        Returns True when the request is allowed, False otherwise.
        ``view`` is accepted for interface compatibility and is not used.
        """
        # Client IP
        ip = request.META.get("REMOTE_ADDR")
        timestamp = time.time()

        history = VISIT_RECORD.setdefault(ip, [])
        self.history = history

        # The oldest entry sits at the end of the list; drop expired ones.
        while history and history[-1] < timestamp - self.WINDOW_SECONDS:
            history.pop()

        if len(history) >= self.MAX_REQUESTS:
            # Rejected requests are deliberately NOT recorded, otherwise a
            # client that keeps retrying would extend its own lock-out.
            return False

        history.insert(0, timestamp)
        return True

    def wait(self):
        """
        Return how many seconds remain until the oldest recorded request
        leaves the window, or None when there is no recorded history.
        """
        if not self.history:
            return None
        timestamp = time.time()
        return self.WINDOW_SECONDS - (timestamp - self.history[-1])
2.视图级别限制设置
class CommentViewSet(ModelViewSet):
    """View set for ``models.Comment`` with view-level throttling."""

    queryset = models.Comment.objects.all()
    serializer_class = app01_serializers.CommentSerializer
    # View-level throttling: the classes listed here are read from
    # ``throttle_classes`` and instantiated for this view.
    throttle_classes = [MyThrottle, ]
3.全局级别限制设置
#在settings.py中设置rest framework相关配置项
# Project-wide REST framework configuration (lives in settings.py).
REST_FRAMEWORK = {
    "DEFAULT_AUTHENTICATION_CLASSES": ["app01.utils.MyAuth", ],
    # The comma after this entry was missing, which made the whole
    # dict literal a SyntaxError.
    "DEFAULT_PERMISSION_CLASSES": ["app01.utils.MyPermission", ],
    # Global throttling: applied to every view that does not override
    # throttle_classes itself.
    "DEFAULT_THROTTLE_CLASSES": ["app01.utils.MyThrottle", ],
}
4.内置限制类
from rest_framework.throttling import SimpleRateThrottle
class VisitThrottle(SimpleRateThrottle):
    """Built-in rate throttle keyed by client identity under scope "xxx"."""

    # Scope name; the rate for it is looked up under this key in
    # DEFAULT_THROTTLE_RATES.
    scope = "xxx"

    def get_cache_key(self, request, view):
        """
        Return the cache key under which this client's history is stored.

        The key is built with ``cache_format`` so that it is namespaced by
        scope; a bare ident would collide with other throttles (or other
        cache users) that key on the same client identity.
        """
        return self.cache_format % {
            "scope": self.scope,
            "ident": self.get_ident(request),
        }
案例
#访问频率的代码:
class VisitThrottles(object):
    """
    Access-frequency throttle: at most three accepted requests per client
    IP within 60 seconds.

    Relies on a module-level dict named ``visit_record`` (IP -> list of
    accepted timestamps, newest first) that must exist in this module.
    """

    # Maximum number of accepted requests inside one window.
    MAX_REQUESTS = 3
    # Length of the sliding window, in seconds.
    WINDOW_SECONDS = 60

    def __init__(self, history=None):
        # ``history`` is optional so the class can be instantiated with no
        # arguments; as before, the argument itself is ignored and the
        # instance starts with an empty history.
        self.history = []

    def allow_request(self, request, view):
        """
        Decide whether the request may proceed.

        Returns True and records the access while fewer than
        ``MAX_REQUESTS`` accepted timestamps remain in the window; returns
        False otherwise. ``view`` is not used.
        """
        remote_addr = request.META.get('REMOTE_ADDR')  # client IP address
        ctime = time.time()

        history = visit_record.setdefault(remote_addr, [])
        self.history = history

        # The oldest entry sits at the end of the list; drop expired ones.
        while history and history[-1] < ctime - self.WINDOW_SECONDS:
            history.pop()

        if len(history) < self.MAX_REQUESTS:
            history.insert(0, ctime)
            return True
        # Limit reached: reject (this branch used to return True, which
        # meant the throttle never blocked anything).
        return False

    # Backward-compatible alias for the earlier, misspelled method name.
    allow_requests = allow_request

    def wait(self):
        """
        Return how many seconds remain until the oldest recorded request
        leaves the window, or None when there is no recorded history.
        """
        if not self.history:
            return None
        ctime = time.time()
        return self.WINDOW_SECONDS - (ctime - self.history[-1])